def precis(posterior, corr=False, digits=2):
    """
    Build a summary table of posterior latent sites: mean, standard deviation,
    and the 0.89 HPDI, one row per scalar component.

    :param posterior: a ``TracePosterior`` (samples are extracted from it) or a
        mapping from site name to a tensor of samples whose first dimension
        indexes draws.
    :param bool corr: if True, append one correlation column per table row,
        computed from ``vcov(posterior)``.
    :param int digits: number of decimal places in the returned table.
    :returns: a ``pandas.DataFrame`` rounded to ``digits``.
    """
    if isinstance(posterior, TracePosterior):
        samples = extract_samples(posterior)
    else:
        samples = posterior

    table = pd.DataFrame(columns=["Mean", "StdDev", "|0.89", "0.89|"])
    for name, draws in samples.items():
        if draws.dim() == 1:
            lower, upper = stats.hpdi(draws, prob=0.89)
            table.loc[name] = [draws.mean().item(), draws.std().item(),
                               lower.item(), upper.item()]
        else:
            # Flatten event dims and emit one row per component: "name[i]".
            flat = draws.reshape(draws.size(0), -1)
            means = flat.mean(0)
            stds = flat.std(0)
            interval = stats.hpdi(flat, prob=0.89)
            for j in range(means.size(0)):
                table.loc["{}[{}]".format(name, j)] = [
                    means[j].item(), stds[j].item(),
                    interval[0, j].item(), interval[1, j].item(),
                ]

    # correct `intercept` due to "centering trick"
    if isinstance(posterior, LM) and "Intercept" in table.index and posterior.centering:
        center = posterior._get_centering_constant(table["Mean"].to_dict()).item()
        table.loc["Intercept", ["Mean", "|0.89", "0.89|"]] -= center

    if corr:
        cov = vcov(posterior)
        # outer product of the diagonal is var_i * var_j; sqrt gives std_i * std_j
        corr_matrix = cov / cov.diag().ger(cov.diag()).sqrt()
        for pos, name in enumerate(table.index):
            table[name] = corr_matrix[:, pos]

    if isinstance(posterior, MCMC):
        diagnostics = posterior.marginal(table.index.tolist()).diagnostics()
        table = pd.concat([table, pd.DataFrame(diagnostics).T.astype(float)], axis=1)
    return table.round(digits)
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Build per-site diagnostics of posterior ``samples`` and return them as a
    dictionary of tables. Each site gets mean, standard deviation, median, the
    ``prob`` credibility interval, :func:`~pyro.ops.stats.effective_sample_size`
    and :func:`~pyro.ops.stats.split_gelman_rubin`.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: probability mass of samples within the credibility
        interval.
    :param bool group_by_chain: If True, each variable in ``samples`` is treated
        as having shape `num_chains x num_samples x sample_shape`; otherwise
        `num_samples x sample_shape` (no chain dimension).
    """
    if not group_by_chain:
        # Insert a singleton chain dimension so both layouts share one code path.
        samples = {k: v.unsqueeze(0) for k, v in samples.items()}

    lower_label = '{:.1f}%'.format(50 * (1 - prob))
    upper_label = '{:.1f}%'.format(50 * (1 + prob))
    summary_dict = {}
    for site, value in samples.items():
        # Merge chain and draw dims for per-draw statistics; the chain-aware
        # diagnostics (n_eff, r_hat) consume the unflattened tensor.
        flat = value.reshape((-1,) + value.shape[2:])
        hpd = stats.hpdi(flat, prob=prob)
        summary_dict[site] = OrderedDict([
            ("mean", flat.mean(dim=0)),
            ("std", flat.std(dim=0)),
            ("median", flat.median(dim=0)[0]),
            (lower_label, hpd[0]),
            (upper_label, hpd[1]),
            ("n_eff", _safe(stats.effective_sample_size)(value)),
            ("r_hat", stats.split_gelman_rubin(value)),
        ])
    return summary_dict
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Print a diagnostics table for posterior ``samples``: mean, standard
    deviation, median, the ``prob`` credibility interval,
    :func:`~pyro.ops.stats.effective_sample_size` and
    :func:`~pyro.ops.stats.split_gelman_rubin`, one row per scalar component.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: probability mass of samples within the credibility
        interval.
    :param bool group_by_chain: If True, each variable in ``samples`` is treated
        as having shape `num_chains x num_samples x sample_shape`; otherwise
        `num_samples x sample_shape` (no chain dimension).
    """
    if not group_by_chain:
        # Insert a singleton chain dimension so both layouts share one code path.
        samples = {k: v.unsqueeze(0) for k, v in samples.items()}

    # Column width is set by the widest row label: "site[i,j,...]" with the
    # largest index in each event dimension.
    labels = [
        k + '[' + ','.join(str(d - 1) for d in v.shape[2:]) + ']'
        for k, v in samples.items()
    ]
    max_len = max(max(len(label) for label in labels), 10)
    name_format = '{:>' + str(max_len) + '}'
    header_format = name_format + ' {:>9}' * 7
    row_format = name_format + ' {:>9.2f}' * 7
    columns = [
        '', 'mean', 'std', 'median',
        '{:.1f}%'.format(50 * (1 - prob)),
        '{:.1f}%'.format(50 * (1 + prob)),
        'n_eff', 'r_hat',
    ]
    print('\n')
    print(header_format.format(*columns))

    for site, value in samples.items():
        # Merge chain and draw dims for per-draw statistics; the chain-aware
        # diagnostics (n_eff, r_hat) consume the unflattened tensor.
        flat = value.reshape((-1,) + value.shape[2:])
        mean = flat.mean(dim=0)
        sd = flat.std(dim=0)
        median = flat.median(dim=0)[0]
        hpd = stats.hpdi(flat, prob=prob)
        n_eff = _safe(stats.effective_sample_size)(value)
        r_hat = stats.split_gelman_rubin(value)
        event_shape = flat.shape[1:]
        if len(event_shape) == 0:
            print(row_format.format(site, mean, sd, median, hpd[0], hpd[1],
                                    n_eff, r_hat))
        else:
            # One printed row per scalar component of the site.
            for idx in product(*map(range, event_shape)):
                suffix = '[{}]'.format(','.join(map(str, idx)))
                print(row_format.format(site + suffix, mean[idx], sd[idx],
                                        median[idx], hpd[0][idx], hpd[1][idx],
                                        n_eff[idx], r_hat[idx]))
    print('\n')
def get_hpdi_confidence_interval(samples, probs):
    """
    Return a DataFrame describing the HPDI of ``samples``.

    Columns "LPI" and "UPI" each hold ``[percent of samples below the bound,
    bound value]`` for the lower and upper HPDI bound respectively.

    Fix: the upper-bound percentage was previously computed as
    ``100 - (% below lower bound)``, which is only correct for a symmetric
    central interval — an HPDI is generally not central. The fraction below the
    upper bound was already being computed but never used; it is now used.

    :param samples: 1-D tensor of posterior draws.
    :param float probs: probability mass of the HPDI.
    :returns: a 2x2 ``pandas.DataFrame``.
    """
    pi = stats.hpdi(samples, prob=probs)
    # Empirical CDF value at each interval endpoint.
    below_lower = (samples < pi[0].item()).sum().float() / len(samples)
    below_upper = (samples < pi[1].item()).sum().float() / len(samples)
    df = {
        "LPI": [round(below_lower.item(), 2) * 100, round(pi[0].item(), 2)],
        "UPI": [round(below_upper.item(), 2) * 100, round(pi[1].item(), 2)],
    }
    return pd.DataFrame.from_dict(df)
def precis(posterior, corr=False, digits=2):
    """
    Summarize posterior latent sites in a ``pandas.DataFrame`` with columns
    "Mean", "StdDev" and the 0.89 HPDI bounds, one row per site.

    :param posterior: a ``TracePosterior`` (marginals are taken over its first
        trace's stochastic nodes) or a mapping from site name to a tensor of
        samples.
    :param bool corr: if True, append one correlation column per stochastic
        node, computed from ``vcov(posterior)``.
    :param int digits: number of decimal places in the returned table.
    """
    if isinstance(posterior, TracePosterior):
        node_supports = posterior.marginal(
            posterior.exec_traces[0].stochastic_nodes).support()
    else:
        node_supports = posterior

    # Column-major layout: one dict per output column, keyed by node name.
    columns = {"Mean": {}, "StdDev": {}, "|0.89": {}, "0.89|": {}}
    for node, support in node_supports.items():
        lower, upper = stats.hpdi(support, prob=0.89)
        columns["Mean"][node] = support.mean().item()
        columns["StdDev"][node] = support.std().item()
        columns["|0.89"][node] = lower.item()
        columns["0.89|"][node] = upper.item()

    # correct `intercept` due to "centering trick"
    if isinstance(posterior, LM):
        center = posterior._get_centering_constant(columns["Mean"])
        for key in ("Mean", "|0.89", "0.89|"):
            columns[key]["intercept"] = columns[key]["intercept"] - center

    table = pd.DataFrame.from_dict(columns)

    if corr:
        cov = vcov(posterior)
        # outer product of the diagonal is var_i * var_j; sqrt gives std_i * std_j
        corr_matrix = cov / cov.diag().ger(cov.diag()).sqrt()
        corr_columns = {
            node: corr_matrix[:, pos].tolist()
            for pos, node in enumerate(posterior.exec_traces[0].stochastic_nodes)
        }
        table = table.assign(**corr_columns)
    return table.round(digits)
def _hpdi(x, dim=0):
    """Shorthand: the 80% highest posterior density interval of ``x`` along ``dim``."""
    return hpdi(x, dim=dim, prob=0.8)
def test_hpdi():
    """hpdi matches the central interval on symmetric data and hugs zero on exponential data."""
    symmetric_draws = torch.randn(20000)
    # For a symmetric distribution the HPDI coincides with the central interval.
    assert_equal(hpdi(symmetric_draws, prob=0.8), pi(symmetric_draws, prob=0.8), prec=0.01)

    skewed_draws = torch.empty(20000).exponential_(1)
    # Exponential density is maximal at 0, so the 20% HPDI starts at 0.
    assert_equal(hpdi(skewed_draws, prob=0.2), torch.tensor([0.0, 0.22]), prec=0.01)
def save_stats(model, path, CI=0.95, save_matlab=False):
    """
    Compute and save posterior summary statistics for a fitted ``model``.

    Populates ``model.params`` with credible-interval statistics, spot
    probabilities, display ranges (vmin/vmax), SNR/chi2 values and (if labels
    are available) classification metrics; populates ``model.summary`` with a
    table of global parameters. When ``path`` is not None, writes
    ``<name>-params.tpqr``, ``<name>-summary.csv`` and optionally a MATLAB
    ``.mat`` export.

    Fixes applied:
    - Restored the ``h_specific`` computation, which had been commented out
      while the following statement still referenced the name (NameError at
      runtime).
    - Summary column labels are now derived from ``CI`` everywhere. Previously
      the DataFrame header used f"{int(100*CI)}% LL/UL" while assignments
      hardcoded "95% LL"/"95% UL", silently creating extra columns whenever
      ``CI != 0.95``. For the default ``CI=0.95`` output is unchanged.

    :param model: fitted model object providing ``_global_params``, probability
        tensors and the experimental data.
    :param path: output directory, or ``None`` to skip writing files.
    :param float CI: probability mass of the credible interval.
    :param bool save_matlab: also export parameters in MATLAB format.
    """
    ll_label = f"{int(100*CI)}% LL"
    ul_label = f"{int(100*CI)}% UL"
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", ll_label, ul_label],
    )
    logger.info("- credible intervals")
    ci_stats = compute_ci(model, model.ci_params, CI)
    for param in global_params:
        # Scalar parameters are stored with .item(), vector ones as lists.
        if ci_stats[param]["Mean"].ndim == 0:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, ll_label] = ci_stats[param]["LL"].item()
            summary.loc[param, ul_label] = ci_stats[param]["UL"].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].tolist()
            summary.loc[param, ll_label] = ci_stats[param]["LL"].tolist()
            summary.loc[param, ul_label] = ci_stats[param]["UL"].tolist()
    logger.info("- spot probabilities")
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()
    ci_stats["p_specific"] = ci_stats["theta_probs"].sum(0)
    # calculate vmin/vmax (display ranges consumed downstream)
    theta_mask = ci_stats["theta_probs"] > 0.5
    if theta_mask.sum():
        hmax = np.percentile(ci_stats["height"]["Mean"][theta_mask], 99)
    else:
        # No confident spots: fall back to a nonzero range.
        hmax = 1
    ci_stats["height"]["vmin"] = -0.03 * hmax
    ci_stats["height"]["vmax"] = 1.3 * hmax
    ci_stats["width"]["vmin"] = 0.5
    ci_stats["width"]["vmax"] = 2.5
    ci_stats["x"]["vmin"] = -9
    ci_stats["x"]["vmax"] = 9
    ci_stats["y"]["vmin"] = -9
    ci_stats["y"]["vmax"] = 9
    bmax = np.percentile(ci_stats["background"]["Mean"].flatten(), 99)
    ci_stats["background"]["vmin"] = -0.03 * bmax
    ci_stats["background"]["vmax"] = 1.3 * bmax
    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb
    # intensity of target-specific spots
    theta_mask = torch.argmax(ci_stats["theta_probs"], dim=0)
    # Restored: this computation was commented out while the assignment below
    # still referenced `h_specific`, raising NameError at runtime.
    h_specific = Vindex(ci_stats["height"]["Mean"])[
        theta_mask, torch.arange(model.data.Nt)[:, None], torch.arange(model.data.F)
    ]
    ci_stats["h_specific"] = h_specific * (ci_stats["z_map"] > 0).long()
    model.params = ci_stats
    logger.info("- SNR and Chi2-test")
    # snr and chi2 test, accumulated one AOI at a time
    snr = torch.zeros(
        model.K, model.data.Nt, model.data.F, model.Q, device=torch.device("cpu")
    )
    chi2 = torch.zeros(model.data.Nt, model.data.F, model.Q, device=torch.device("cpu"))
    for n in range(model.data.Nt):
        snr[:, n], chi2[n] = snr_and_chi2(
            model.data.images[n],
            ci_stats["height"]["Mean"][:, n],
            ci_stats["width"]["Mean"][:, n],
            ci_stats["x"]["Mean"][:, n],
            ci_stats["y"]["Mean"][:, n],
            model.data.xy[n],
            ci_stats["background"]["Mean"][n],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            ci_stats["theta_probs"][:, n],
        )
    # Average SNR over confidently target-specific spots only.
    snr_masked = snr[ci_stats["theta_probs"] > 0.5]
    summary.loc["SNR", "Mean"] = snr_masked.mean().item()
    ci_stats["chi2"] = {}
    ci_stats["chi2"]["values"] = chi2
    cmax = quantile(ci_stats["chi2"]["values"].flatten(), 0.99)
    ci_stats["chi2"]["vmin"] = -0.03 * cmax
    ci_stats["chi2"]["vmax"] = 1.3 * cmax
    # classification statistics (only when ground-truth labels are available)
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][:model.data.N].ravel()
        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(
            true_labels, pred_labels, zero_division=0
        )
        summary.loc["Precision", "Mean"] = precision_score(
            true_labels, pred_labels, zero_division=0
        )
        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()
        mask = torch.from_numpy(model.data.labels["z"][:model.data.N]) > 0
        samples = torch.masked_select(
            model.z_probs[model.data.is_ontarget].cpu(), mask
        )
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", ll_label] = z_ll.item()
            summary.loc["p(specific)", ul_label] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", ll_label] = 0.0
            summary.loc["p(specific)", ul_label] = 0.0
    model.summary = summary
    if path is not None:
        path = Path(path)
        param_path = path / f"{model.name}-params.tpqr"
        torch.save(ci_stats, param_path)
        logger.info(f"Parameters were saved in {param_path}")
        if save_matlab:
            from scipy.io import savemat

            for param, field in ci_stats.items():
                # Plain tensors/arrays are converted directly; nested stat dicts
                # are converted value-by-value.
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "p_specific",
                    "h_specific",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = np.asarray(field)
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = np.asarray(value)
            mat_path = path / f"{model.name}-params.mat"
            savemat(mat_path, ci_stats)
            logger.info(f"Matlab parameters were saved in {mat_path}")
        csv_path = path / f"{model.name}-summary.csv"
        summary.to_csv(csv_path)
        logger.info(f"Summary statistics were saved in {csv_path}")
def save_stats(model, path, CI=0.95, save_matlab=False):
    """
    Sample posterior credible intervals from the fitted variational guide and
    save summary statistics for ``model``.

    Populates ``model.params`` with per-parameter "Mean"/"LL"/"UL" statistics
    (global parameters sampled from their guide distributions; local parameters
    processed in AOI/frame batches), and ``model.summary`` with a table of
    global parameters, SNR and (if labels are available) classification
    metrics. When ``path`` is not None, writes ``<full_name>-params.tpqr``,
    ``<full_name>-summary.csv`` and optionally a MATLAB ``.mat`` export.

    Fix: summary column labels are now derived from ``CI`` everywhere.
    Previously the DataFrame header used f"{int(100*CI)}% LL/UL" while
    assignments hardcoded "95% LL"/"95% UL", silently creating extra columns
    whenever ``CI != 0.95``. For the default ``CI=0.95`` output is unchanged.

    :param model: fitted model object providing ``_global_params``, guide
        parameters via ``pyro.param`` and the experimental data.
    :param path: output directory, or ``None`` to skip writing files.
    :param float CI: probability mass of the credible interval.
    :param bool save_matlab: also export parameters in MATLAB format.
    """
    ll_label = f"{int(100*CI)}% LL"
    ul_label = f"{int(100*CI)}% UL"
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", ll_label, ul_label],
    )
    # local parameters
    local_params = [
        "height",
        "width",
        "x",
        "y",
        "background",
    ]
    ci_stats = defaultdict(partial(defaultdict, list))
    num_samples = 10000
    for param in global_params:
        # Reconstruct each parameter's guide distribution from pyro.param store.
        if param == "gain":
            fn = dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            )
        elif param == "pi":
            fn = dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size"))
        elif param == "lamda":
            fn = dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            )
        elif param == "proximity":
            fn = AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (model.data.P + 1) / math.sqrt(12),
            )
        elif param == "trans":
            fn = dist.Dirichlet(
                pyro.param("trans_mean") * pyro.param("trans_size")
            ).to_event(1)
        else:
            raise NotImplementedError
        samples = fn.sample((num_samples,)).data.squeeze()
        ci_stats[param] = {}
        LL, UL = hpdi(samples, CI, dim=0)
        ci_stats[param]["LL"] = LL.cpu()
        ci_stats[param]["UL"] = UL.cpu()
        ci_stats[param]["Mean"] = fn.mean.data.squeeze().cpu()
        # calculate Keq = p / (1 - p) from the "pi" samples
        if param == "pi":
            ci_stats["Keq"] = {}
            LL, UL = hpdi(samples[:, 1] / (1 - samples[:, 1]), CI, dim=0)
            ci_stats["Keq"]["LL"] = LL.cpu()
            ci_stats["Keq"]["UL"] = UL.cpu()
            ci_stats["Keq"]["Mean"] = (
                samples[:, 1] / (1 - samples[:, 1])
            ).mean().cpu()
    # this does not need to be very accurate
    num_samples = 1000
    for param in local_params:
        LL, UL, Mean = [], [], []
        # Batch over AOIs (ndx) and frames (fdx) to bound memory use;
        # frame batches are concatenated along the last dim, AOI batches
        # along the second-to-last.
        for ndx in torch.split(torch.arange(model.data.Nt), model.nbatch_size):
            ndx = ndx[:, None]
            kdx = torch.arange(model.K)[:, None, None]
            ll, ul, mean = [], [], []
            for fdx in torch.split(torch.arange(model.data.F), model.fbatch_size):
                if param == "background":
                    fn = dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx]
                        * Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    )
                elif param == "height":
                    fn = dist.Gamma(
                        Vindex(pyro.param("h_loc"))[kdx, ndx, fdx]
                        * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                        Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                    )
                elif param == "width":
                    fn = AffineBeta(
                        Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                        0.75,
                        2.25,
                    )
                elif param == "x":
                    fn = AffineBeta(
                        Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                elif param == "y":
                    fn = AffineBeta(
                        Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                else:
                    raise NotImplementedError
                samples = fn.sample((num_samples,)).data
                l, u = hpdi(samples, CI, dim=0)
                m = fn.mean.data
                ll.append(l)
                ul.append(u)
                mean.append(m)
            # NOTE: the original attached these as redundant `for...else`
            # clauses; neither loop contains `break`, so they always ran.
            LL.append(torch.cat(ll, -1))
            UL.append(torch.cat(ul, -1))
            Mean.append(torch.cat(mean, -1))
        ci_stats[param]["LL"] = torch.cat(LL, -2).cpu()
        ci_stats[param]["UL"] = torch.cat(UL, -2).cpu()
        ci_stats[param]["Mean"] = torch.cat(Mean, -2).cpu()
    for param in global_params:
        if param == "pi":
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"][1].item()
            summary.loc[param, ll_label] = ci_stats[param]["LL"][1].item()
            summary.loc[param, ul_label] = ci_stats[param]["UL"][1].item()
            # Keq
            summary.loc["Keq", "Mean"] = ci_stats["Keq"]["Mean"].item()
            summary.loc["Keq", ll_label] = ci_stats["Keq"]["LL"].item()
            summary.loc["Keq", ul_label] = ci_stats["Keq"]["UL"].item()
        elif param == "trans":
            # Off-diagonal transition rates: kon = 0->1, koff = 1->0.
            summary.loc["kon", "Mean"] = ci_stats[param]["Mean"][0, 1].item()
            summary.loc["kon", ll_label] = ci_stats[param]["LL"][0, 1].item()
            summary.loc["kon", ul_label] = ci_stats[param]["UL"][0, 1].item()
            summary.loc["koff", "Mean"] = ci_stats[param]["Mean"][1, 0].item()
            summary.loc["koff", ll_label] = ci_stats[param]["LL"][1, 0].item()
            summary.loc["koff", ul_label] = ci_stats[param]["UL"][1, 0].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, ll_label] = ci_stats[param]["LL"].item()
            summary.loc[param, ul_label] = ci_stats[param]["UL"].item()
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()
    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb
    model.params = ci_stats
    # snr
    summary.loc["SNR", "Mean"] = (
        snr(
            model.data.images[:, :, model.cdx],
            ci_stats["width"]["Mean"],
            ci_stats["x"]["Mean"],
            ci_stats["y"]["Mean"],
            model.data.xy[:, :, model.cdx],
            ci_stats["background"]["Mean"],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            model.theta_probs,
        )
        .mean()
        .item()
    )
    # classification statistics (only when ground-truth labels are available)
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][: model.data.N, :, model.cdx].ravel()
        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(
            true_labels, pred_labels, zero_division=0
        )
        summary.loc["Precision", "Mean"] = precision_score(
            true_labels, pred_labels, zero_division=0
        )
        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()
        # NOTE(review): unlike the sibling save_stats, no `> 0` is applied here,
        # so labels["z"] is presumably already boolean — masked_select requires
        # a bool mask; confirm against the dataset loader.
        mask = torch.from_numpy(model.data.labels["z"][: model.data.N, :, model.cdx])
        samples = torch.masked_select(model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", ll_label] = z_ll.item()
            summary.loc["p(specific)", ul_label] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", ll_label] = 0.0
            summary.loc["p(specific)", ul_label] = 0.0
    model.summary = summary
    if path is not None:
        path = Path(path)
        torch.save(ci_stats, path / f"{model.full_name}-params.tpqr")
        if save_matlab:
            from scipy.io import savemat

            for param, field in ci_stats.items():
                # Plain tensors are converted directly; nested stat dicts are
                # converted value-by-value.
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = field.numpy()
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = value.cpu().numpy()
            savemat(path / f"{model.full_name}-params.mat", ci_stats)
        summary.to_csv(
            path / f"{model.full_name}-summary.csv",
        )