Пример #1
0
def precis(posterior, corr=False, digits=2):
    """
    Build a summary DataFrame (mean, std, 0.89 HPDI bounds) for each latent
    site of ``posterior``; optionally append correlation columns and, for
    MCMC posteriors, convergence diagnostics.
    """
    if isinstance(posterior, TracePosterior):
        supports = extract_samples(posterior)
    else:
        supports = posterior
    table = pd.DataFrame(columns=["Mean", "StdDev", "|0.89", "0.89|"])
    for name, samples in supports.items():
        if samples.dim() == 1:
            interval = stats.hpdi(samples, prob=0.89)
            table.loc[name] = [samples.mean().item(), samples.std().item(),
                               interval[0].item(), interval[1].item()]
        else:
            # Flatten event dims so each scalar component gets its own row.
            flat = samples.reshape(samples.size(0), -1)
            means = flat.mean(0)
            stds = flat.std(0)
            interval = stats.hpdi(flat, prob=0.89)
            for j in range(means.size(0)):
                table.loc["{}[{}]".format(name, j)] = [
                    means[j].item(), stds[j].item(),
                    interval[0, j].item(), interval[1, j].item()]
    # correct `intercept` due to "centering trick"
    if isinstance(posterior, LM) and "Intercept" in table.index and posterior.centering:
        shift = posterior._get_centering_constant(table["Mean"].to_dict()).item()
        table.loc["Intercept", ["Mean", "|0.89", "0.89|"]] -= shift

    if corr:
        # Normalize the covariance into a correlation matrix.
        cov = vcov(posterior)
        corr_matrix = cov / cov.diag().ger(cov.diag()).sqrt()
        for j, name in enumerate(table.index):
            table[name] = corr_matrix[:, j]

    if isinstance(posterior, MCMC):
        diag = posterior.marginal(table.index.tolist()).diagnostics()
        table = pd.concat([table, pd.DataFrame(diag).T.astype(float)], axis=1)

    return table.round(digits)
Пример #2
0
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Returns a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the 90% Credibility Interval, :func:`~pyro.ops.stats.effective_sample_size`,
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    """
    if not group_by_chain:
        # Add a singleton chain dimension so the code below is uniform.
        samples = {site: chains.unsqueeze(0) for site, chains in samples.items()}

    # Interval column labels, e.g. "5.0%" / "95.0%" for prob=0.9.
    lower_label = '{:.1f}%'.format(50 * (1 - prob))
    upper_label = '{:.1f}%'.format(50 * (1 + prob))

    result = {}
    for site, chains in samples.items():
        # Collapse chain and sample dims for the point statistics.
        flat = torch.reshape(chains, (-1,) + chains.shape[2:])
        interval = stats.hpdi(flat, prob=prob)
        result[site] = OrderedDict([
            ("mean", flat.mean(dim=0)),
            ("std", flat.std(dim=0)),
            ("median", flat.median(dim=0)[0]),
            (lower_label, interval[0]),
            (upper_label, interval[1]),
            # Diagnostics use the un-flattened tensor (chain structure matters).
            ("n_eff", _safe(stats.effective_sample_size)(chains)),
            ("r_hat", stats.split_gelman_rubin(chains)),
        ])
    return result
Пример #3
0
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Prints a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the 90% Credibility Interval, :func:`~pyro.ops.stats.effective_sample_size`,
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    """
    if not group_by_chain:
        # Insert a singleton chain dimension so shapes are uniform below.
        samples = {site: chains.unsqueeze(0) for site, chains in samples.items()}

    # Width of the name column: the longest "site[max indices]" label
    # (indices shown zero-based), with a floor of 10 characters.
    label_lengths = [
        len('{}[{}]'.format(site, ','.join(str(dim - 1) for dim in chains.shape[2:])))
        for site, chains in samples.items()
    ]
    max_len = max(max(label_lengths), 10)
    name_format = '{:>' + str(max_len) + '}'
    header_format = name_format + ' {:>9} {:>9} {:>9} {:>9} {:>9} {:>9} {:>9}'
    row_format = name_format + ' {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f}'

    print('\n')
    print(header_format.format(
        '', 'mean', 'std', 'median', '{:.1f}%'.format(50 * (1 - prob)),
        '{:.1f}%'.format(50 * (1 + prob)), 'n_eff', 'r_hat'))

    for site, chains in samples.items():
        # Collapse chain and sample dims for the point statistics.
        flat = torch.reshape(chains, (-1, ) + chains.shape[2:])
        site_mean = flat.mean(dim=0)
        site_sd = flat.std(dim=0)
        site_median = flat.median(dim=0)[0]
        interval = stats.hpdi(flat, prob=prob)
        # Diagnostics use the un-flattened tensor (chain structure matters).
        site_n_eff = _safe(stats.effective_sample_size)(chains)
        site_r_hat = stats.split_gelman_rubin(chains)
        event_shape = flat.shape[1:]
        if len(event_shape) == 0:
            print(row_format.format(site, site_mean, site_sd, site_median,
                                    interval[0], interval[1], site_n_eff,
                                    site_r_hat))
        else:
            # One row per scalar component of a multivariate site.
            for idx in product(*(range(size) for size in event_shape)):
                suffix = '[{}]'.format(','.join(str(i) for i in idx))
                print(row_format.format(site + suffix, site_mean[idx],
                                        site_sd[idx], site_median[idx],
                                        interval[0][idx], interval[1][idx],
                                        site_n_eff[idx], site_r_hat[idx]))
    print('\n')
Пример #4
0
def get_hpdi_confidence_interval(samples, probs):
    """
    Summarize the highest-posterior-density interval (HPDI) of ``samples``.

    :param samples: 1-D tensor of posterior samples.
    :param probs: probability mass of the interval (forwarded to
        ``stats.hpdi`` as ``prob``).
    :return: DataFrame with columns "LPI" and "UPI"; row 0 holds the
        empirical percentile of each interval endpoint, row 1 the endpoint
        value itself.
    """
    pi = stats.hpdi(samples, prob=probs)
    # Empirical fraction of samples falling below each interval endpoint.
    below_interval_1 = (samples < pi[0].item()).sum().float() / len(samples)
    below_interval_2 = (samples < pi[1].item()).sum().float() / len(samples)

    df = {
        "LPI":
        [round(below_interval_1.item(), 2) * 100,
         round(pi[0].item(), 2)],
        "UPI": [
            # Bug fix: the upper endpoint's percentile comes from the
            # fraction of samples below the *upper* bound. The original
            # computed `below_interval_2` but never used it, deriving the
            # value from the lower bound instead (only valid for symmetric
            # intervals, which HPDIs generally are not).
            round(below_interval_2.item(), 2) * 100,
            round(pi[1].item(), 2)
        ]
    }
    df = pd.DataFrame.from_dict(df)
    return df
Пример #5
0
def precis(posterior, corr=False, digits=2):
    """
    Summarize each latent site of ``posterior`` with mean, standard
    deviation, and 0.89 HPDI bounds; optionally append correlation columns.
    For LM posteriors, undo the "centering trick" on the intercept.
    """
    if isinstance(posterior, TracePosterior):
        supports = posterior.marginal(
            posterior.exec_traces[0].stochastic_nodes).support()
    else:
        supports = posterior
    # One dict per output column, keyed by node name.
    columns = {"Mean": {}, "StdDev": {}, "|0.89": {}, "0.89|": {}}
    for node, samples in supports.items():
        interval = stats.hpdi(samples, prob=0.89)
        columns["Mean"][node] = samples.mean().item()
        columns["StdDev"][node] = samples.std().item()
        columns["|0.89"][node] = interval[0].item()
        columns["0.89|"][node] = interval[1].item()
    # correct `intercept` due to "centering trick"
    if isinstance(posterior, LM):
        shift = posterior._get_centering_constant(columns["Mean"])
        for label in ("Mean", "|0.89", "0.89|"):
            columns[label]["intercept"] = columns[label]["intercept"] - shift
    table = pd.DataFrame.from_dict(columns)

    if corr:
        # Normalize the covariance into a correlation matrix and attach one
        # column per stochastic node.
        cov = vcov(posterior)
        corr_matrix = cov / cov.diag().ger(cov.diag()).sqrt()
        corr_columns = {
            node: corr_matrix[:, pos].tolist()
            for pos, node in enumerate(posterior.exec_traces[0].stochastic_nodes)
        }
        table = table.assign(**corr_columns)
    return table.round(digits)
Пример #6
0
def _hpdi(x, dim=0):
    """Highest-posterior-density interval of ``x`` at the fixed 80% level."""
    return hpdi(x, dim=dim, prob=0.8)
Пример #7
0
def test_hpdi():
    """For symmetric data the HPDI matches the equal-tailed interval; for
    exponential data it hugs the known left edge."""
    symmetric = torch.randn(20000)
    assert_equal(hpdi(symmetric, prob=0.8), pi(symmetric, prob=0.8), prec=0.01)

    skewed = torch.empty(20000).exponential_(1)
    assert_equal(hpdi(skewed, prob=0.2), torch.tensor([0.0, 0.22]), prec=0.01)
Пример #8
0
def save_stats(model, path, CI=0.95, save_matlab=False):
    """
    Compute credible intervals, SNR/chi2, and (when labels are available)
    classification statistics for ``model``, then save them under ``path``.

    :param model: fitted model exposing ``_global_params``, ``ci_params``,
        probability tensors (``m_probs``, ``theta_probs``, ...), and ``data``.
    :param path: output directory; ``None`` skips saving to disk.
    :param float CI: probability mass of the credible interval.
    :param bool save_matlab: additionally export parameters as a ``.mat`` file.

    Side effects: sets ``model.params`` and ``model.summary``.
    """
    # Derive the CI column labels once. Previously "95% LL"/"95% UL" were
    # hard-coded in the assignments below, so any CI != 0.95 silently created
    # extra columns and left the declared ones empty.
    ll_label = f"{int(100*CI)}% LL"
    ul_label = f"{int(100*CI)}% UL"
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", ll_label, ul_label],
    )

    logger.info("- credible intervals")
    ci_stats = compute_ci(model, model.ci_params, CI)

    for param in global_params:
        # Scalars go in as floats; vector parameters as plain lists.
        if ci_stats[param]["Mean"].ndim == 0:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, ll_label] = ci_stats[param]["LL"].item()
            summary.loc[param, ul_label] = ci_stats[param]["UL"].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].tolist()
            summary.loc[param, ll_label] = ci_stats[param]["LL"].tolist()
            summary.loc[param, ul_label] = ci_stats[param]["UL"].tolist()
    logger.info("- spot probabilities")
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()
    ci_stats["p_specific"] = ci_stats["theta_probs"].sum(0)

    # calculate vmin/vmax (plotting color-scale limits per parameter)
    theta_mask = ci_stats["theta_probs"] > 0.5
    if theta_mask.sum():
        hmax = np.percentile(ci_stats["height"]["Mean"][theta_mask], 99)
    else:
        hmax = 1
    ci_stats["height"]["vmin"] = -0.03 * hmax
    ci_stats["height"]["vmax"] = 1.3 * hmax
    ci_stats["width"]["vmin"] = 0.5
    ci_stats["width"]["vmax"] = 2.5
    ci_stats["x"]["vmin"] = -9
    ci_stats["x"]["vmax"] = 9
    ci_stats["y"]["vmin"] = -9
    ci_stats["y"]["vmax"] = 9
    bmax = np.percentile(ci_stats["background"]["Mean"].flatten(), 99)
    ci_stats["background"]["vmin"] = -0.03 * bmax
    ci_stats["background"]["vmax"] = 1.3 * bmax

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    # intensity of target-specific spots
    theta_mask = torch.argmax(ci_stats["theta_probs"], dim=0)
    #  h_specific = Vindex(ci_stats["height"]["Mean"])[
    #      theta_mask, torch.arange(model.data.Nt)[:, None], torch.arange(model.data.F)
    #  ]
    #  ci_stats["h_specific"] = h_specific * (ci_stats["z_map"] > 0).long()

    model.params = ci_stats

    logger.info("- SNR and Chi2-test")
    # snr and chi2 test, computed per AOI on the CPU
    snr = torch.zeros(model.K,
                      model.data.Nt,
                      model.data.F,
                      model.Q,
                      device=torch.device("cpu"))
    chi2 = torch.zeros(model.data.Nt,
                       model.data.F,
                       model.Q,
                       device=torch.device("cpu"))
    for n in range(model.data.Nt):
        snr[:, n], chi2[n] = snr_and_chi2(
            model.data.images[n],
            ci_stats["height"]["Mean"][:, n],
            ci_stats["width"]["Mean"][:, n],
            ci_stats["x"]["Mean"][:, n],
            ci_stats["y"]["Mean"][:, n],
            model.data.xy[n],
            ci_stats["background"]["Mean"][n],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            ci_stats["theta_probs"][:, n],
        )
    # Average SNR over spots that are likely target-specific.
    snr_masked = snr[ci_stats["theta_probs"] > 0.5]
    summary.loc["SNR", "Mean"] = snr_masked.mean().item()
    ci_stats["chi2"] = {}
    ci_stats["chi2"]["values"] = chi2
    cmax = quantile(ci_stats["chi2"]["values"].flatten(), 0.99)
    ci_stats["chi2"]["vmin"] = -0.03 * cmax
    ci_stats["chi2"]["vmax"] = 1.3 * cmax

    # classification statistics (only when ground-truth labels exist)
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][:model.data.N].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC",
                        "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(true_labels,
                                                     pred_labels,
                                                     zero_division=0)
        summary.loc["Precision", "Mean"] = precision_score(true_labels,
                                                           pred_labels,
                                                           zero_division=0)

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        # p(specific) over truly-specific frames (labels > 0).
        mask = torch.from_numpy(model.data.labels["z"][:model.data.N]) > 0
        samples = torch.masked_select(
            model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", ll_label] = z_ll.item()
            summary.loc["p(specific)", ul_label] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", ll_label] = 0.0
            summary.loc["p(specific)", ul_label] = 0.0

    model.summary = summary

    if path is not None:
        path = Path(path)
        param_path = path / f"{model.name}-params.tpqr"
        torch.save(ci_stats, param_path)
        logger.info(f"Parameters were saved in {param_path}")
        if save_matlab:
            from scipy.io import savemat

            # savemat cannot handle torch tensors; convert everything to numpy.
            for param, field in ci_stats.items():
                if param in (
                        "m_probs",
                        "theta_probs",
                        "z_probs",
                        "z_map",
                        "p_specific",
                        "h_specific",
                        "time1",
                        "ttb",
                ):
                    ci_stats[param] = np.asarray(field)
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = np.asarray(value)
            mat_path = path / f"{model.name}-params.mat"
            savemat(mat_path, ci_stats)
            logger.info(f"Matlab parameters were saved in {mat_path}")
        csv_path = path / f"{model.name}-summary.csv"
        summary.to_csv(csv_path)
        logger.info(f"Summary statistics were saved in {csv_path}")
Пример #9
0
def save_stats(model, path, CI=0.95, save_matlab=False):
    """
    Compute credible intervals for global and local parameters of ``model``
    by sampling its variational posterior, assemble summary/classification
    statistics, and save everything under ``path``.

    :param model: fitted model exposing ``_global_params``, pyro params,
        probability tensors, and ``data``.
    :param path: output directory; ``None`` skips saving to disk.
    :param float CI: probability mass of the credible interval.
    :param bool save_matlab: additionally export parameters as a ``.mat`` file.

    Side effects: sets ``model.params`` and ``model.summary``.
    """
    # Derive the CI column labels once. Previously "95% LL"/"95% UL" were
    # hard-coded in the assignments below, so any CI != 0.95 silently created
    # extra columns and left the declared ones empty.
    ll_label = f"{int(100*CI)}% LL"
    ul_label = f"{int(100*CI)}% UL"
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", ll_label, ul_label],
    )
    # local parameters
    local_params = [
        "height",
        "width",
        "x",
        "y",
        "background",
    ]

    ci_stats = defaultdict(partial(defaultdict, list))
    num_samples = 10000
    for param in global_params:
        # Reconstruct each global parameter's variational posterior from its
        # pyro params, then sample it to estimate the HPDI.
        if param == "gain":
            fn = dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            )
        elif param == "pi":
            fn = dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size"))
        elif param == "lamda":
            fn = dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            )
        elif param == "proximity":
            fn = AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (model.data.P + 1) / math.sqrt(12),
            )
        elif param == "trans":
            fn = dist.Dirichlet(
                pyro.param("trans_mean") * pyro.param("trans_size")
            ).to_event(1)
        else:
            raise NotImplementedError
        samples = fn.sample((num_samples,)).data.squeeze()
        ci_stats[param] = {}
        LL, UL = hpdi(
            samples,
            CI,
            dim=0,
        )
        ci_stats[param]["LL"] = LL.cpu()
        ci_stats[param]["UL"] = UL.cpu()
        ci_stats[param]["Mean"] = fn.mean.data.squeeze().cpu()

        # calculate Keq = pi_1 / (1 - pi_1)
        if param == "pi":
            ci_stats["Keq"] = {}
            LL, UL = hpdi(samples[:, 1] / (1 - samples[:, 1]), CI, dim=0)
            ci_stats["Keq"]["LL"] = LL.cpu()
            ci_stats["Keq"]["UL"] = UL.cpu()
            ci_stats["Keq"]["Mean"] = (samples[:, 1] / (1 - samples[:, 1])).mean().cpu()

    # this does not need to be very accurate
    num_samples = 1000
    for param in local_params:
        LL, UL, Mean = [], [], []
        # Process AOIs and frames in batches to bound memory use.
        for ndx in torch.split(torch.arange(model.data.Nt), model.nbatch_size):
            ndx = ndx[:, None]
            kdx = torch.arange(model.K)[:, None, None]
            ll, ul, mean = [], [], []
            for fdx in torch.split(torch.arange(model.data.F), model.fbatch_size):
                if param == "background":
                    fn = dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx]
                        * Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    )
                elif param == "height":
                    fn = dist.Gamma(
                        Vindex(pyro.param("h_loc"))[kdx, ndx, fdx]
                        * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                        Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                    )
                elif param == "width":
                    fn = AffineBeta(
                        Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                        0.75,
                        2.25,
                    )
                elif param == "x":
                    fn = AffineBeta(
                        Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                elif param == "y":
                    fn = AffineBeta(
                        Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                else:
                    raise NotImplementedError
                samples = fn.sample((num_samples,)).data
                l, u = hpdi(
                    samples,
                    CI,
                    dim=0,
                )
                m = fn.mean.data
                ll.append(l)
                ul.append(u)
                mean.append(m)
            # Stitch the frame batches back together. (The original used
            # `for ... else` here; with no `break` in the loop that is
            # identical to plain post-loop code.)
            LL.append(torch.cat(ll, -1))
            UL.append(torch.cat(ul, -1))
            Mean.append(torch.cat(mean, -1))
        # Stitch the AOI batches back together.
        ci_stats[param]["LL"] = torch.cat(LL, -2).cpu()
        ci_stats[param]["UL"] = torch.cat(UL, -2).cpu()
        ci_stats[param]["Mean"] = torch.cat(Mean, -2).cpu()

    for param in global_params:
        if param == "pi":
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"][1].item()
            summary.loc[param, ll_label] = ci_stats[param]["LL"][1].item()
            summary.loc[param, ul_label] = ci_stats[param]["UL"][1].item()
            # Keq
            summary.loc["Keq", "Mean"] = ci_stats["Keq"]["Mean"].item()
            summary.loc["Keq", ll_label] = ci_stats["Keq"]["LL"].item()
            summary.loc["Keq", ul_label] = ci_stats["Keq"]["UL"].item()
        elif param == "trans":
            # Off-diagonal transition probabilities: kon = p(0 -> 1),
            # koff = p(1 -> 0).
            summary.loc["kon", "Mean"] = ci_stats[param]["Mean"][0, 1].item()
            summary.loc["kon", ll_label] = ci_stats[param]["LL"][0, 1].item()
            summary.loc["kon", ul_label] = ci_stats[param]["UL"][0, 1].item()
            summary.loc["koff", "Mean"] = ci_stats[param]["Mean"][1, 0].item()
            summary.loc["koff", ll_label] = ci_stats[param]["LL"][1, 0].item()
            summary.loc["koff", ul_label] = ci_stats[param]["UL"][1, 0].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, ll_label] = ci_stats[param]["LL"].item()
            summary.loc[param, ul_label] = ci_stats[param]["UL"].item()
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    model.params = ci_stats

    # snr
    summary.loc["SNR", "Mean"] = (
        snr(
            model.data.images[:, :, model.cdx],
            ci_stats["width"]["Mean"],
            ci_stats["x"]["Mean"],
            ci_stats["y"]["Mean"],
            model.data.xy[:, :, model.cdx],
            ci_stats["background"]["Mean"],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            model.theta_probs,
        )
        .mean()
        .item()
    )

    # classification statistics (only when ground-truth labels exist)
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][: model.data.N, :, model.cdx].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(
            true_labels, pred_labels, zero_division=0
        )
        summary.loc["Precision", "Mean"] = precision_score(
            true_labels, pred_labels, zero_division=0
        )

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        # NOTE(review): masked_select expects a boolean mask; presumably
        # labels["z"] is already boolean here — confirm (the analogous code
        # path elsewhere compares with `> 0` before masking).
        mask = torch.from_numpy(model.data.labels["z"][: model.data.N, :, model.cdx])
        samples = torch.masked_select(model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", ll_label] = z_ll.item()
            summary.loc["p(specific)", ul_label] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", ll_label] = 0.0
            summary.loc["p(specific)", ul_label] = 0.0

    model.summary = summary

    if path is not None:
        path = Path(path)
        torch.save(ci_stats, path / f"{model.full_name}-params.tpqr")
        if save_matlab:
            from scipy.io import savemat

            # savemat cannot handle torch tensors; convert everything to numpy.
            for param, field in ci_stats.items():
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = field.numpy()
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = value.cpu().numpy()
            savemat(path / f"{model.full_name}-params.mat", ci_stats)
        summary.to_csv(
            path / f"{model.full_name}-summary.csv",
        )