def test_quantile():
    """Check ``quantile`` on one exact case and on two large random samples."""
    ramp = torch.tensor([0.0, 1.0, 2.0])
    # Draw order (uniform first, then normal) is kept so the RNG stream is consumed identically.
    uniform_draws = torch.rand(2000)
    normal_draws = torch.randn(2000)

    # Exact case: the expected values are consistent with linear interpolation
    # between the three sorted points.
    expected_ramp = torch.tensor([0.0, 0.8, 1.0, 2.0])
    assert_equal(quantile(ramp, probs=[0.0, 0.4, 0.5, 1.0]), expected_ramp)

    # Statistical cases, compared with a tolerance of 0.02.
    uniform_q = quantile(uniform_draws, probs=0.2)
    assert_equal(uniform_q, torch.tensor(0.2), prec=0.02)

    normal_q = quantile(normal_draws, probs=0.8413)
    assert_equal(normal_q, torch.tensor(1.0), prec=0.02)
def main():
    """Fit a Pyro ``Forecaster`` on log-transformed data, then plot and save the forecast.

    The last 500 time steps are held out as the test window. The forecast
    quantiles (10%, 50%, 90%) are drawn against the observed series, one
    subplot per column of ``data``, and the figure is written to
    ``figures/pyro_forecast.png`` before being shown.
    """
    # The second element returned by load_input() is not used here.
    inputs, _calendar = load_input()
    logger.info('Inference')
    # NOTE(review): relies on inputs.values() yielding exactly three entries in
    # this order -- confirm against load_input().
    covariates, _covariate_dim, data = inputs.values()
    data, covariates = map(jax_to_torch, [data, covariates])
    data = torch.log(1 + data.double())
    assert pyro.__version__.startswith('1.3.1')
    pyro.enable_validation(True)
    T0 = 0  # beginning
    T2 = data.size(-2)  # end
    T1 = T2 - 500  # train/test split
    pyro.set_rng_seed(1)
    pyro.clear_param_store()
    data = data.permute(-2, -1)
    covariates = covariates.reshape(data.size(-1), T2, -1)
    # covariates = torch.zeros(len(data), 0)  # empty
    forecaster = Forecaster(
        Model4(),
        data[:T1],
        covariates[:, :T1],
        learning_rate=0.09,
        num_steps=2000,
    )
    samples = forecaster(data[:T1], covariates[:, :T2], num_samples=336)
    samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
    p10, p50, p90 = quantile(samples[:, 0], [0.1, 0.5, 0.9]).squeeze(-1)
    crps = eval_crps(samples, data[T1:T2])
    print(samples.shape, p10.shape)

    fig, axes = plt.subplots(data.size(-1), 1, figsize=(9, 10), sharex=True)
    plt.subplots_adjust(hspace=0)
    axes[0].set_title("Sales (CRPS = {:0.3g})".format(crps))
    for i, ax in enumerate(axes):
        ax.fill_between(torch.arange(T1, T2), p10[:, i], p90[:, i], color="red", alpha=0.3)
        ax.plot(torch.arange(T1, T2), p50[:, i], 'r-', lw=1, label='forecast')
        ax.plot(torch.arange(T0, T2), data[:T2, i], 'k-', lw=1, label='truth')
        ax.set_ylabel(f"item: {i}")
    axes[0].legend(loc="best")

    # Save BEFORE showing: once plt.show() returns the figure has usually been
    # released, so a later plt.savefig() would write an empty canvas.
    fig.savefig('figures/pyro_forecast.png')
    plt.show()
def vmin(self) -> torch.Tensor:
    """Return the 0.05 quantile of ``self.images`` for each of the ``self.C`` channels.

    The per-channel values are stacked into a single tensor with one entry
    per channel index.
    """
    lower_bounds = []
    for channel in range(self.C):
        # Pick one channel along the third-from-last axis and flatten it.
        pixels = self.images[..., channel, :, :].flatten().float()
        lower_bounds.append(quantile(pixels, 0.05))
    return torch.stack(lower_bounds)
def vmax(self) -> torch.Tensor:
    """Return the 0.99 quantile of ``self.images`` for each of the ``self.C`` channels.

    The result is produced by ``torch.stack`` and is therefore a tensor with
    one entry per channel index (the previous ``-> int`` annotation was
    incorrect).
    """
    return torch.stack(
        [
            # One channel along the third-from-last axis, flattened.
            quantile(self.images[..., c, :, :].flatten().float(), 0.99)
            for c in range(self.C)
        ]
    )
def main():
    """Train a forecaster on the first 700 steps of ``data.pt`` and plot steps 700-1404.

    Prints the scalar entries of the guide's median, draws 100 forecast
    samples, and plots the 10%/50%/90% quantiles against the observed values.
    """
    split, horizon_end = 700, 1404

    # Load, swap the last two axes, keep the first entry and add a trailing axis.
    series = torch.load('data.pt').transpose(-1, -2)[0][:, None]

    pyro.set_rng_seed(1)
    pyro.clear_param_store()

    # Zero-width covariates: one row per time step, no columns.
    no_covariates = torch.zeros(len(series), 0)
    train = series[:split]
    forecaster = Forecaster(
        MM(),
        train,
        no_covariates[:split],
        learning_rate=0.1,
        num_steps=400,
    )

    # Report only single-element entries of the guide median.
    for name, value in forecaster.guide.median().items():
        if value.numel() != 1:
            continue
        print("{} = {:0.4g}".format(name, value.item()))

    samples = forecaster(train, no_covariates, num_samples=100)
    bands = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
    p10, p50, p90 = bands
    eval_crps(samples, series[split:])

    plt.figure(figsize=(9, 3))
    steps = torch.arange(split, horizon_end)
    plt.fill_between(steps, p10, p90, color="red", alpha=0.3)
    plt.plot(steps, p50, 'r-', label='forecast')
    plt.plot(steps, series[split:horizon_end], 'k-', label='truth')
    plt.xlim(split, horizon_end)
    plt.legend(loc="best")
    plt.show()
def get_counterfactual(data_dict, forecaster, res, R0):
    """Run the SEIR integrator with a supplied ``R0`` and summarise the result.

    Reads fitted quantities from ``res``, integrates with
    ``pyro_model.helper.eluer_seir_time``, convolves the scaled last ODE
    output with a zero-extended mask, differences along axis 1 and returns
    the 0.05 / 0.5 / 0.95 quantiles over axis 0 as a NumPy array.

    NOTE(review): tensor layouts are not visible from this function; the
    permutations below are reproduced exactly from the previous version.
    """
    population = forecaster.model.N

    fatality = res['p_fatal']
    imported = res['case_import'] / population
    infectious_days = res['infect_days']
    incubation = res['d_incubation']
    death_delay = res['time_to_death'] - infectious_days

    incubation_rate = 1. / incubation
    removal_rate = 1. / infectious_days
    transmission = removal_rate * R0

    # Initial compartments: everything not imported starts in the first compartment.
    e_init = imported
    i_init = torch.zeros_like(e_init)
    r_init = torch.zeros_like(e_init)
    f_init = torch.zeros_like(e_init)
    s_init = 1. - e_init - i_init - r_init - f_init

    # t_init: same shape with i0
    start_times = data_dict['t_init'].unsqueeze(0).repeat(i_init.size(0), 1, 1)

    ode_out = pyro_model.helper.eluer_seir_time(
        s_init, e_init, i_init, r_init, f_init,
        transmission, removal_rate, incubation_rate,
        fatality, death_delay, imported, start_times,
    )
    scaled = ode_out[-1] * population
    # ..., t, p
    scaled = scaled.unsqueeze(-1).transpose(-1, -3)

    half_mask = res['mask_half']
    # Extend the mask with zeros on the right, then drop the final position.
    full_mask = torch.cat([half_mask, torch.zeros_like(half_mask)], dim=-1)[..., :-1]
    pad = half_mask.shape[-1] - 1
    n_series = len(population)

    def _convolve_sample(idx):
        # Convolve each series of one sample with its own kernel slice.
        kernel = full_mask[idx, ...][0, ...].permute(1, 0, 2)
        signal = scaled[idx, ...].permute(2, 0, 1)
        pieces = [
            torch.nn.functional.conv1d(
                signal[j:j + 1, ...], kernel[j:j + 1, ...], padding=pad)
            for j in range(n_series)
        ]
        return torch.cat(pieces, dim=0).permute(1, 2, 0)

    convolved = [_convolve_sample(idx) for idx in range(scaled.shape[0])]
    stacked = torch.stack(convolved, dim=0).numpy().squeeze()
    increments = np.diff(stacked, axis=1)
    summary = quantile(torch.tensor(increments), (0.05, 0.5, 0.95), dim=0)
    return summary.numpy()
def vmin(self) -> float:
    """Return the 0.05 quantile over all values of ``self.images`` as a Python scalar.

    The images are flattened and cast to float before the quantile is taken,
    so ``.item()`` yields a ``float`` (the previous ``-> int`` annotation was
    incorrect).
    """
    return quantile(self.images.flatten().float(), 0.05).item()
for seed in range(25): model_id = 'day-{}-rng-{}'.format(days, seed) try: with open(prefix + 'Loop{}/{}-predictive.pkl'.format(days, model_id), 'rb') as f: res = pickle.load(f) except Exception: continue with open(prefix + 'Loop{}/{}-forecaster.csv'.format(days, model_id), 'rb') as f: forecaster = pickle.load(f) prediction = quantile(res['prediction'].squeeze(), (0.5, ), dim=0).squeeze() err = np.diff(prediction, axis=0) - data_dict['daily_death'][1:, ].numpy() err_country = np.sum(np.abs(err), axis=0) err_country_list.append(err_country) seed_list.append(seed) err = np.stack(err_country_list, axis=0) best_model = np.argmin(err, axis=0) best_err = np.min(err, axis=0) best_seed = [seed_list[x] for x in best_model] df = pds.DataFrame({ 'country': countries, 'best_seed': best_seed, 'best_err': best_err
continue with open('Loop{}/{}-forecaster.csv'.format(days, model_id), 'wb') as f: pickle.dump(forecaster, f) samples = forecaster(Y_train, covariates_full_notime, num_samples=n_sample, batch_size=50) samples = samples.squeeze() init = Y_train[-1, :][None, None, :] init = init.repeat(samples.shape[0], 1, 1) samples = torch.cat([init, samples], dim=1) daily_s = samples[:, 1:, :] - samples[:, :-1, :] p10, p50, p90 = quantile(daily_s, (0.1, 0.5, 0.9), dim=0).squeeze(-1) rmse = torch.sqrt( torch.mean((p50[-days:, :] - Y_daily[-days:, :])**2, dim=0)).squeeze().numpy() off = (torch.sum(p50[-days:, :], dim=0) - torch.sum(Y_daily[-days:, :], dim=0)).squeeze().numpy() for i in zip(countries, rmse, off): print(i) d = {'countries': countries, 'rmse': rmse, 'total_error': off} df = pds.DataFrame(data=d) df.to_csv('Loop{}/{}-rmse.csv'.format(days, model_id)) with open('Loop{}/{}-samples.pkl'.format(days, model_id), 'wb') as f: pickle.dump(samples.detach().numpy(), f) R0low, R0mid, R0high, map_estimates = model.get_R0(forecaster, Y_train,
predictive_list.append(predictive) with open(prefix + 'Loop{}/{}-samples.pkl'.format(days, model_id), 'rb') as f: samples = pickle.load(f) samples_list.append(samples) seed_list.append(seed) # validation accuracy val_window = 14 seir_error_list = [] for i in range(len(predictive_list)): seir_train = quantile(predictive_list[i]['prediction'].squeeze(), 0.5, dim=0)[-val_window + 1:, :].numpy() seir_train = np.diff(seir_train, axis=0) seir_label = data_dict['daily_death'][train_len - val_window:train_len, :].numpy() seir_error = np.abs( np.sum(seir_train, axis=0) - np.sum(seir_label, axis=0)) seir_error_list.append(seir_error) seir_error = np.stack(seir_error_list, axis=0) best_model = np.argmin(seir_error, axis=0) best_seed = [seed_list[x] for x in best_model] test_len = 14
def save_stats(model, path, CI=0.95, save_matlab=False):
    """Compute posterior summaries for ``model``, attach them to it and optionally save them.

    For every global and local parameter, samples are drawn from the fitted
    distribution (built from values in the Pyro param store) and the mean and
    highest-density interval at level ``CI`` are recorded. A summary table is
    assembled, extended with SNR and, when labels are available,
    classification metrics.

    Args:
        model: project model object. This function reads ``_global_params``,
            ``data``, ``K``, ``nbatch_size``, ``fbatch_size``, ``cdx``,
            ``m_probs``, ``theta_probs``, ``z_probs``, ``z_map`` and
            ``full_name`` from it.
        path: output directory, or ``None`` to skip writing files.
        CI: interval mass passed to ``hpdi``; it also determines the
            ``"<pct>% LL"`` / ``"<pct>% UL"`` summary column names.
        save_matlab: when true (and ``path`` is given) also write a ``.mat``
            file with the statistics converted to NumPy arrays.

    Side effects:
        Sets ``model.params`` and ``model.summary``. When ``path`` is not
        ``None`` writes ``<full_name>-params.tpqr``, ``<full_name>-summary.csv``
        and optionally ``<full_name>-params.mat`` under it.

    Raises:
        NotImplementedError: for a parameter name with no sampling rule below.
    """
    # Interval column labels are derived once from CI. Previously the writes
    # below used hard-coded "95% LL"/"95% UL", which only matched the columns
    # created here when CI was 0.95.
    ll_col = f"{int(100*CI)}% LL"
    ul_col = f"{int(100*CI)}% UL"

    def _write_row(row, mean, ll, ul):
        # Mean is written first (this creates the row when it does not exist yet).
        summary.loc[row, "Mean"] = mean
        summary.loc[row, ll_col] = ll
        summary.loc[row, ul_col] = ul

    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", ll_col, ul_col],
    )
    # local parameters
    local_params = [
        "height",
        "width",
        "x",
        "y",
        "background",
    ]

    ci_stats = defaultdict(partial(defaultdict, list))
    num_samples = 10000
    for param in global_params:
        if param == "gain":
            fn = dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            )
        elif param == "pi":
            fn = dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size"))
        elif param == "lamda":
            fn = dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            )
        elif param == "proximity":
            fn = AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (model.data.P + 1) / math.sqrt(12),
            )
        elif param == "trans":
            fn = dist.Dirichlet(
                pyro.param("trans_mean") * pyro.param("trans_size")
            ).to_event(1)
        else:
            raise NotImplementedError
        samples = fn.sample((num_samples,)).data.squeeze()
        ci_stats[param] = {}
        LL, UL = hpdi(
            samples,
            CI,
            dim=0,
        )
        ci_stats[param]["LL"] = LL.cpu()
        ci_stats[param]["UL"] = UL.cpu()
        ci_stats[param]["Mean"] = fn.mean.data.squeeze().cpu()

        # calculate Keq
        if param == "pi":
            # Ratio of the second component to its complement, per sample.
            keq_samples = samples[:, 1] / (1 - samples[:, 1])
            ci_stats["Keq"] = {}
            LL, UL = hpdi(keq_samples, CI, dim=0)
            ci_stats["Keq"]["LL"] = LL.cpu()
            ci_stats["Keq"]["UL"] = UL.cpu()
            ci_stats["Keq"]["Mean"] = keq_samples.mean().cpu()

    # this does not need to be very accurate
    num_samples = 1000
    for param in local_params:
        LL, UL, Mean = [], [], []
        # Work in (ndx, fdx) tiles so only one tile of samples is held at a time.
        for ndx in torch.split(torch.arange(model.data.Nt), model.nbatch_size):
            ndx = ndx[:, None]
            kdx = torch.arange(model.K)[:, None, None]
            ll, ul, mean = [], [], []
            for fdx in torch.split(torch.arange(model.data.F), model.fbatch_size):
                if param == "background":
                    fn = dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx]
                        * Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    )
                elif param == "height":
                    fn = dist.Gamma(
                        Vindex(pyro.param("h_loc"))[kdx, ndx, fdx]
                        * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                        Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                    )
                elif param == "width":
                    fn = AffineBeta(
                        Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                        0.75,
                        2.25,
                    )
                elif param == "x":
                    fn = AffineBeta(
                        Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                elif param == "y":
                    fn = AffineBeta(
                        Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                else:
                    raise NotImplementedError
                samples = fn.sample((num_samples,)).data
                l, u = hpdi(
                    samples,
                    CI,
                    dim=0,
                )
                m = fn.mean.data
                ll.append(l)
                ul.append(u)
                mean.append(m)
            # The loops contain no `break`, so the former `for ... else`
            # clauses always ran; they are plain post-loop statements now.
            LL.append(torch.cat(ll, -1))
            UL.append(torch.cat(ul, -1))
            Mean.append(torch.cat(mean, -1))
        ci_stats[param]["LL"] = torch.cat(LL, -2).cpu()
        ci_stats[param]["UL"] = torch.cat(UL, -2).cpu()
        ci_stats[param]["Mean"] = torch.cat(Mean, -2).cpu()

    for param in global_params:
        stats = ci_stats[param]
        if param == "pi":
            _write_row(
                param,
                stats["Mean"][1].item(),
                stats["LL"][1].item(),
                stats["UL"][1].item(),
            )
            # Keq
            keq = ci_stats["Keq"]
            _write_row("Keq", keq["Mean"].item(), keq["LL"].item(), keq["UL"].item())
        elif param == "trans":
            _write_row(
                "kon",
                stats["Mean"][0, 1].item(),
                stats["LL"][0, 1].item(),
                stats["UL"][0, 1].item(),
            )
            _write_row(
                "koff",
                stats["Mean"][1, 0].item(),
                stats["LL"][1, 0].item(),
                stats["UL"][1, 0].item(),
            )
        else:
            _write_row(
                param,
                stats["Mean"].item(),
                stats["LL"].item(),
                stats["UL"].item(),
            )

    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    model.params = ci_stats

    # snr
    summary.loc["SNR", "Mean"] = (
        snr(
            model.data.images[:, :, model.cdx],
            ci_stats["width"]["Mean"],
            ci_stats["x"]["Mean"],
            ci_stats["y"]["Mean"],
            model.data.xy[:, :, model.cdx],
            ci_stats["background"]["Mean"],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            model.theta_probs,
        )
        .mean()
        .item()
    )

    # classification statistics
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][: model.data.N, :, model.cdx].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(
            true_labels, pred_labels, zero_division=0
        )
        summary.loc["Precision", "Mean"] = precision_score(
            true_labels, pred_labels, zero_division=0
        )

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        # Probabilities at the positions selected by the label mask.
        mask = torch.from_numpy(model.data.labels["z"][: model.data.N, :, model.cdx])
        samples = torch.masked_select(model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            # The "Mean" cell holds the median (0.5 quantile) of the selected values.
            _write_row(
                "p(specific)",
                quantile(samples, 0.5).item(),
                z_ll.item(),
                z_ul.item(),
            )
        else:
            _write_row("p(specific)", 0.0, 0.0, 0.0)

    model.summary = summary

    if path is not None:
        path = Path(path)
        torch.save(ci_stats, path / f"{model.full_name}-params.tpqr")
        if save_matlab:
            from scipy.io import savemat

            # Convert in place to NumPy. Only existing keys are reassigned, so
            # the dict size does not change during iteration. Note that
            # model.params refers to this same dict.
            for param, field in ci_stats.items():
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = field.numpy()
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = value.cpu().numpy()
            savemat(path / f"{model.full_name}-params.mat", ci_stats)
        summary.to_csv(
            path / f"{model.full_name}-summary.csv",
        )
with open(prefix + 'AblationLoop{}/{}-predictive.pkl'.format(days, model_id), 'rb') as f: predictive = pickle.load(f) except Exception: continue predictive_list.append(predictive) with open(prefix + 'AblationLoop{}/{}-samples.pkl'.format(days, model_id), 'rb') as f: samples = pickle.load(f) samples_list.append(samples) val_window = 14 seir_error_list = [] for i in range(len(predictive_list)): seir_train = quantile(predictive_list[i]['prediction'].squeeze(), 0.5, dim=0)[-val_window + 1:].numpy() seir_train = np.diff(seir_train, axis=0) seir_label = data_dict['daily_death'][train_len - val_window:train_len, :].numpy() seir_error = np.abs(np.sum(seir_train, axis=0) - np.sum(seir_label, axis=0)) seir_error_list.append(seir_error) seir_error = np.stack(seir_error_list, axis=0) best_model = np.argmin(seir_error, axis=0) test_len = 14 best_error_list = [] test_len = test_len - 1 for j, i in zip(range(len(countries)), best_model): c = countries[j]
def save_stats(model, path, CI=0.95, save_matlab=False):
    """Compute summary statistics for ``model``, attach them to it and optionally save them.

    Credible intervals come from ``compute_ci``; this function adds spot
    probabilities, display ranges (``vmin``/``vmax``), SNR and chi2 values
    and, when labels are available, classification metrics.

    Args:
        model: project model object. This function reads ``_global_params``,
            ``ci_params``, ``data``, ``K``, ``Q``, ``m_probs``,
            ``theta_probs``, ``z_probs``, ``z_map`` and ``name`` from it.
        path: output directory, or ``None`` to skip writing files.
        CI: interval level passed to ``compute_ci`` / ``hpdi``; it also
            determines the ``"<pct>% LL"`` / ``"<pct>% UL"`` column names.
        save_matlab: when true (and ``path`` is given) also write a ``.mat``
            file with the statistics converted to NumPy arrays.

    Side effects:
        Sets ``model.params`` and ``model.summary``. When ``path`` is not
        ``None`` writes ``<name>-params.tpqr``, ``<name>-summary.csv`` and
        optionally ``<name>-params.mat`` under it, logging each location.
    """
    # Interval column labels are derived once from CI. Previously the writes
    # below used hard-coded "95% LL"/"95% UL", which only matched the columns
    # created here when CI was 0.95.
    ll_col = f"{int(100*CI)}% LL"
    ul_col = f"{int(100*CI)}% UL"

    def _write_row(row, mean, ll, ul):
        # Mean is written first (this creates the row when it does not exist yet).
        summary.loc[row, "Mean"] = mean
        summary.loc[row, ll_col] = ll
        summary.loc[row, ul_col] = ul

    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", ll_col, ul_col],
    )

    logger.info("- credible intervals")
    ci_stats = compute_ci(model, model.ci_params, CI)

    for param in global_params:
        stats = ci_stats[param]
        if stats["Mean"].ndim == 0:
            # Scalar parameter: store plain Python numbers.
            _write_row(param, stats["Mean"].item(), stats["LL"].item(), stats["UL"].item())
        else:
            # Non-scalar parameter: store nested lists.
            _write_row(param, stats["Mean"].tolist(), stats["LL"].tolist(), stats["UL"].tolist())

    logger.info("- spot probabilities")
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()
    ci_stats["p_specific"] = ci_stats["theta_probs"].sum(0)

    # calculate vmin/vmax
    theta_mask = ci_stats["theta_probs"] > 0.5
    if theta_mask.sum():
        hmax = np.percentile(ci_stats["height"]["Mean"][theta_mask], 99)
    else:
        hmax = 1
    ci_stats["height"]["vmin"] = -0.03 * hmax
    ci_stats["height"]["vmax"] = 1.3 * hmax
    ci_stats["width"]["vmin"] = 0.5
    ci_stats["width"]["vmax"] = 2.5
    ci_stats["x"]["vmin"] = -9
    ci_stats["x"]["vmax"] = 9
    ci_stats["y"]["vmin"] = -9
    ci_stats["y"]["vmax"] = 9
    bmax = np.percentile(ci_stats["background"]["Mean"].flatten(), 99)
    ci_stats["background"]["vmin"] = -0.03 * bmax
    ci_stats["background"]["vmax"] = 1.3 * bmax

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    # NOTE: "h_specific" (intensity of target-specific spots) is not computed
    # here; the key is still accepted by the MATLAB export below if present.

    model.params = ci_stats

    logger.info("- SNR and Chi2-test")
    # snr and chi2 test
    snr = torch.zeros(
        model.K, model.data.Nt, model.data.F, model.Q, device=torch.device("cpu")
    )
    chi2 = torch.zeros(
        model.data.Nt, model.data.F, model.Q, device=torch.device("cpu")
    )
    # Filled one index along the Nt axis at a time.
    for n in range(model.data.Nt):
        snr[:, n], chi2[n] = snr_and_chi2(
            model.data.images[n],
            ci_stats["height"]["Mean"][:, n],
            ci_stats["width"]["Mean"][:, n],
            ci_stats["x"]["Mean"][:, n],
            ci_stats["y"]["Mean"][:, n],
            model.data.xy[n],
            ci_stats["background"]["Mean"][n],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            ci_stats["theta_probs"][:, n],
        )
    # Average SNR only over entries whose theta probability exceeds 0.5.
    snr_masked = snr[ci_stats["theta_probs"] > 0.5]
    summary.loc["SNR", "Mean"] = snr_masked.mean().item()
    ci_stats["chi2"] = {}
    ci_stats["chi2"]["values"] = chi2
    cmax = quantile(ci_stats["chi2"]["values"].flatten(), 0.99)
    ci_stats["chi2"]["vmin"] = -0.03 * cmax
    ci_stats["chi2"]["vmax"] = 1.3 * cmax

    # classification statistics
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][:model.data.N].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(true_labels, pred_labels, zero_division=0)
        summary.loc["Precision", "Mean"] = precision_score(true_labels, pred_labels, zero_division=0)

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        # Probabilities at the positions where the label is positive.
        mask = torch.from_numpy(model.data.labels["z"][:model.data.N]) > 0
        samples = torch.masked_select(
            model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            # The "Mean" cell holds the median (0.5 quantile) of the selected values.
            _write_row(
                "p(specific)",
                quantile(samples, 0.5).item(),
                z_ll.item(),
                z_ul.item(),
            )
        else:
            _write_row("p(specific)", 0.0, 0.0, 0.0)

    model.summary = summary

    if path is not None:
        path = Path(path)
        param_path = path / f"{model.name}-params.tpqr"
        torch.save(ci_stats, param_path)
        logger.info(f"Parameters were saved in {param_path}")
        if save_matlab:
            from scipy.io import savemat

            # Convert in place to NumPy. Only existing keys are reassigned, so
            # the dict size does not change during iteration. Note that
            # model.params refers to this same dict.
            for param, field in ci_stats.items():
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "p_specific",
                    "h_specific",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = np.asarray(field)
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = np.asarray(value)
            mat_path = path / f"{model.name}-params.mat"
            savemat(mat_path, ci_stats)
            logger.info(f"Matlab parameters were saved in {mat_path}")
        csv_path = path / f"{model.name}-summary.csv"
        summary.to_csv(csv_path)
        logger.info(f"Summary statistics were saved in {csv_path}")
# Sample the respirator model under its charge prior and plot the penetration
# profile (mean, central quantile band and a fixed threshold line).
seed = 42
np.random.seed(seed)
pyro.set_rng_seed(seed)

# 1000 draws, batched along dim -2 by the plate.
with pyro.plate("respirator", 1000, dim=-2):
    surface_area, layer_params = respirator_model_charge_prior()
    results: torch.Tensor = penetration(surface_area, layer_params)

plt.style.use('ggplot')
plt.rcParams['text.usetex'] = True

THRESHOLD = 6e-2  # 6% threshold

# plt.plot(particle_diam, results)
# 5% and 95% quantiles bound the central 90% of the draws.
low_bound, high_bound = quantile(results, (0.05, 0.95))

fig: plt.Figure = plt.figure()
plt.plot(particle_diam, results.mean(dim=0), label="Mean")
# The label states 90% to match the (0.05, 0.95) quantiles actually plotted
# (it previously said 95%).
plt.fill_between(particle_diam, low_bound, high_bound, alpha=.4,
                 label=r"90\% confidence interval")
plt.hlines(THRESHOLD, particle_diam.min(), particle_diam.max(),
           ls='--', label="Penetration threshold")
plt.xlabel("Particle size $d_p$")
plt.ylabel("Penetration")
plt.xscale('log')
# Raw string: same text as before, without the invalid "\s" escape sequence.
plt.title(r"Penetration profile. Prior charge density $q \sim \mathcal{N}(13, 1)$ nm")
# plt.title(r"Prior charge density $q \sim \mathcal{U}(13, 14)$ nm")
plt.legend()
plt.tight_layout()
plt.show()
def test_pi():
    """``pi`` at prob=0.8 must match the 0.1 and 0.9 quantiles of the same sample."""
    # Exponentiated standard-normal draws.
    draws = torch.randn(1000).exp()
    interval = pi(draws, prob=0.8)
    expected = quantile(draws, probs=[0.1, 0.9])
    assert_equal(interval, expected)
def get_R0(self, forecaster, Y_train, covariates_full, sample_size, batch_size):
    """Sample R0 trajectories from the guide and summarise them.

    Draws ``sample_size // batch_size`` batches from ``forecaster.guide``,
    builds an RBF covariance from the sampled kernel parameters, maps the
    sampled noise through its Cholesky factor and a sigmoid, and scales by
    the sampled ``R00``.

    Returns:
        ``(R0low, R0mid, R0high, map_estimates)``: the 0.1 / 0.5 / 0.9
        quantiles over all draws, and a dict holding each guide output
        concatenated across batches along dim 0.

    NOTE(review): tensor layouts are not visible from this method; the
    reshapes below are reproduced exactly from the previous version.
    """
    batches = sample_size // batch_size
    r0_draws = []
    guide_draws = []

    for _ in range(batches):
        with torch.no_grad(), pyro.plate("particles", batch_size, dim=-3):
            draw = forecaster.guide(Y_train[:1, :], covariates_full)
        guide_draws.append(draw)

        n_series = Y_train.size(-1)
        n_steps = covariates_full.size(-2)
        flat_len = n_series * n_steps

        lengthscale = draw['r0_lengthscale'].transpose(-1, -2)
        variance = draw['r0_kernel_var'].transpose(-1, -2)

        # covariate
        features = pyro_model.helper.reshape_covariates_pyro(
            covariates_full, self.n_country).to(torch.float)
        features = features.transpose(1, 0)
        # dxt, 1, p
        features = features.reshape(flat_len, 1, features.size(-1))

        cov = pyro_model.helper.tensor_RBF(lengthscale, variance, features)
        # Small diagonal term added before factorising.
        cov = cov + torch.eye(cov.shape[-1]) * 0.01
        chol = torch.cholesky(cov)

        noise = draw['r0_iid_n']
        # Extend the sampled noise with fresh normal draws up to n_steps.
        extra = torch.randn(
            noise.size(0), noise.size(1), n_steps - noise.size(-1))
        noise = torch.cat([noise, extra], dim=-1)
        noise = noise.unsqueeze(-1)
        noise = noise.reshape(noise.size(0), 1, flat_len, 1)

        gate = torch.sigmoid(torch.einsum('abij,abjk->abik', chol, noise))
        gate = gate[:, 0, ...]  # get rid of batch dimension
        gate = gate.reshape(gate.size(0), n_series, n_steps, 1)

        r0_draws.append(draw['R00'] * gate[..., 0])

    R0 = torch.cat(r0_draws, dim=0)
    R0low, R0mid, R0high = quantile(R0, (0.1, 0.5, 0.9), dim=0).squeeze(-1)

    # Concatenate every guide output across batches, keyed as in the first batch.
    map_estimates = {
        key: torch.cat([draw[key] for draw in guide_draws], dim=0)
        for key in guide_draws[0].keys()
    }
    return R0low, R0mid, R0high, map_estimates
# Load one fitted run (days=0, seed=1) and plot its median fitted daily
# increments against the observed daily deaths for the first country.
days = 0
train_len = data_dict['cum_death'].shape[0] - days
covariates_actual = pyro_model.helper.get_covariates_intervention(
    data_dict, train_len, notime=True)
Y_train = pyro_model.helper.get_Y(data_dict, train_len)

seed = 1
model_id = 'day-{}-rng-{}'.format(days, seed)

with open(prefix + 'Loop{}/{}-predictive.pkl'.format(days, model_id), 'rb') as f:
    res = pickle.load(f)

# The forecaster file is a pickle despite its ".csv" extension.
with open(prefix + 'Loop{}/{}-forecaster.csv'.format(days, model_id), 'rb') as f:
    forecaster = pickle.load(f)

# Median over dim 0 of the stored predictions.
prediction = quantile(res['prediction'].squeeze(), (0.5,), dim=0).squeeze()

c = 0
dt_list = data_dict['date_list']
start_date = data_dict['t_init'][c] + pad

# NOTE(review): np.diff drops the first element, so increment k spans days
# k -> k+1, yet it is plotted at dt_list[k]. This may be shifted by one day
# relative to the observed series -- confirm the intended alignment.
plt.plot(dt_list[:-days - 1], np.diff(prediction[:, c]), label='Fitted SEIR')
# Legend label typo fixed ("acutal" -> "actual").
plt.plot(dt_list, data_dict['daily_death'][:, c], '.', label='actual')
plt.gcf().autofmt_xdate()
plt.title(countries[c])
plt.legend()

R0 = res['R0'].squeeze()
def test_pi():
    """``pi`` at prob=0.8 must match the 0.1 and 0.9 quantiles of the same sample."""
    # Fill an uninitialised buffer in place with log-normal(0, 1) draws.
    draws = torch.empty(1000)
    draws.log_normal_(0, 1)
    interval = pi(draws, prob=0.8)
    expected = quantile(draws, probs=[0.1, 0.9])
    assert_equal(interval, expected)
def _quantile(x, dim=0):
    """Return the 0.1 and 0.6 quantiles of ``x`` along ``dim``."""
    fixed_probs = [0.1, 0.6]
    result = quantile(x, probs=fixed_probs, dim=dim)
    return result
def vmax(self) -> float:
    """Return the 0.99 quantile over all values of ``self.images`` as a Python scalar.

    The images are flattened and cast to float before the quantile is taken,
    so ``.item()`` yields a ``float`` (the previous ``-> int`` annotation was
    incorrect).
    """
    return quantile(self.images.flatten().float(), 0.99).item()