Example #1
def test_quantile():
    x = torch.tensor([0.0, 1.0, 2.0])
    y = torch.rand(2000)
    z = torch.randn(2000)

    assert_equal(quantile(x, probs=[0.0, 0.4, 0.5, 1.0]),
                 torch.tensor([0.0, 0.8, 1.0, 2.0]))
    assert_equal(quantile(y, probs=0.2), torch.tensor(0.2), prec=0.02)
    assert_equal(quantile(z, probs=0.8413), torch.tensor(1.0), prec=0.02)
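These tests exercise `quantile` from `pyro.ops.stats`, whose call signature (a tensor, one or more probabilities, and an optional `dim`) matches every snippet on this page. A minimal self-contained sketch, assuming only PyTorch and Pyro are installed:

import torch
from pyro.ops.stats import quantile

samples = torch.randn(10000)  # draws from a standard normal
q10, q50, q90 = quantile(samples, probs=[0.1, 0.5, 0.9])  # empirical quantiles along dim=0
print(q10, q50, q90)  # roughly -1.28, 0.0, 1.28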
Example #2
def main():
    inputs, calendar = load_input()
    logger.info('Inference')
    covariates, covariate_dim, data = inputs.values()
    data, covariates = map(jax_to_torch, [data, covariates])
    data = torch.log(1 + data.double())
    assert pyro.__version__.startswith('1.3.1')
    pyro.enable_validation(True)
    T0 = 0  # beginning
    T2 = data.size(-2)  # end
    T1 = T2 - 500  # train/test split
    pyro.set_rng_seed(1)
    pyro.clear_param_store()
    data = data.permute(-2, -1)
    covariates = covariates.reshape(data.size(-1), T2, -1)
    # covariates = torch.zeros(len(data), 0)  # empty
    forecaster = Forecaster(Model4(), data[:T1], covariates[:, :T1],
                            learning_rate=0.09, num_steps=2000)
    samples = forecaster(data[:T1], covariates[:, :T2], num_samples=336)
    samples.clamp_(min=0)  # apply domain knowledge: the samples must be positive
    p10, p50, p90 = quantile(samples[:, 0], [0.1, 0.5, 0.9]).squeeze(-1)
    crps = eval_crps(samples, data[T1:T2])
    print(samples.shape, p10.shape)
    fig, axes = plt.subplots(data.size(-1), 1, figsize=(9, 10), sharex=True)
    plt.subplots_adjust(hspace=0)
    axes[0].set_title("Sales (CRPS = {:0.3g})".format(crps))
    for i, ax in enumerate(axes):
        ax.fill_between(torch.arange(T1, T2), p10[:, i], p90[:, i], color="red", alpha=0.3)
        ax.plot(torch.arange(T1, T2), p50[:, i], 'r-', lw=1, label='forecast')
        ax.plot(torch.arange(T0, T2), data[:T2, i], 'k-', lw=1, label='truth')
        ax.set_ylabel(f"item: {i}")
    axes[0].legend(loc="best")
    plt.savefig('figures/pyro_forecast.png')  # save before show(), or the saved figure is blank
    plt.show()
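Example #2 follows the standard `pyro.contrib.forecast` workflow: fit a `Forecaster` on a training slice, draw posterior samples over the full covariate window, and summarize them with `quantile`. `Model4`, `load_input`, and `jax_to_torch` are local to that script; a stripped-down sketch of the same workflow with a toy model, adapted from Pyro's forecasting tutorial:

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel, Forecaster
from pyro.ops.stats import quantile

class ToyModel(ForecastingModel):
    def model(self, zero_data, covariates):
        data_dim = zero_data.size(-1)
        bias = pyro.sample("bias", dist.Normal(0, 10).expand([data_dim]).to_event(1))
        noise_scale = pyro.sample("noise_scale", dist.LogNormal(-5, 5).expand([1]).to_event(1))
        prediction = bias + zero_data  # constant-mean prediction, broadcast over time
        self.predict(dist.Normal(0, noise_scale), prediction)

data = torch.randn(100, 1).cumsum(0)    # toy random-walk series
covariates = torch.zeros(len(data), 0)  # empty covariates, as in Example #5
forecaster = Forecaster(ToyModel(), data[:80], covariates[:80], num_steps=200)
samples = forecaster(data[:80], covariates, num_samples=50)   # forecast steps 80..99
p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)  # three (20,) bands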
Example #3
 def vmin(self) -> torch.Tensor:
     return torch.stack(
         [
             quantile(self.images[..., c, :, :].flatten().float(), 0.05)
             for c in range(self.C)
         ]
     )
Example #4
 def vmax(self) -> torch.Tensor:
     return torch.stack(
         [
             quantile(self.images[..., c, :, :].flatten().float(), 0.99)
             for c in range(self.C)
         ]
     )
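Examples #3 and #4 compute per-channel display limits (5th and 99th percentiles) over an image stack. On a plain tensor the per-channel loop can also be vectorized with `torch.quantile`; a sketch with a hypothetical (N, C, H, W) stack:

import torch

images = torch.rand(8, 3, 64, 64)                 # hypothetical (N, C, H, W) stack
flat = images.permute(1, 0, 2, 3).reshape(3, -1)  # one flat row per channel
vmin = torch.quantile(flat, 0.05, dim=1)          # per-channel 5th percentile
vmax = torch.quantile(flat, 0.99, dim=1)          # per-channel 99th percentile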
Example #5
def main():
    data = torch.load('data.pt').transpose(-1, -2)
    data = data[0]
    data = data[:, None]
    pyro.set_rng_seed(1)
    pyro.clear_param_store()
    covariates = torch.zeros(len(data), 0)
    forecaster = Forecaster(MM(),
                            data[:700],
                            covariates[:700],
                            learning_rate=0.1,
                            num_steps=400)
    for name, value in forecaster.guide.median().items():
        if value.numel() == 1:
            print("{} = {:0.4g}".format(name, value.item()))

    samples = forecaster(data[:700], covariates, num_samples=100)
    p10, p50, p90 = quantile(samples, (0.1, 0.5, 0.9)).squeeze(-1)
    eval_crps(samples, data[700:])
    plt.figure(figsize=(9, 3))
    plt.fill_between(torch.arange(700, 1404), p10, p90, color="red", alpha=0.3)
    plt.plot(torch.arange(700, 1404), p50, 'r-', label='forecast')
    plt.plot(torch.arange(700, 1404), data[700:1404], 'k-', label='truth')
    plt.xlim(700, 1404)
    plt.legend(loc="best")
    plt.show()
Example #6
def get_counterfactual(data_dict, forecaster, res, R0):
    p_fatal = res['p_fatal']
    case_import = res['case_import'] / forecaster.model.N

    infectious_days = res['infect_days']
    time_to_death = res['time_to_death']
    d_incubation = res['d_incubation']
    alpha = 1. / d_incubation
    d_death = time_to_death - infectious_days

    e0 = case_import
    i0 = torch.zeros_like(e0)
    r0 = torch.zeros_like(e0)
    f0 = torch.zeros_like(e0)
    s0 = 1. - e0 - i0 - r0 - f0

    sigma = 1. / infectious_days
    beta_t = sigma * R0

    # t_init: same shape as i0
    t_init = data_dict['t_init'].unsqueeze(0).repeat(i0.size(0), 1, 1)

    res_ode = pyro_model.helper.eluer_seir_time(s0, e0, i0, r0, f0, beta_t,
                                                sigma, alpha, p_fatal, d_death,
                                                case_import, t_init)
    r_t = res_ode[-1]
    prediction = r_t * forecaster.model.N
    # ..., t, p
    prediction = prediction.unsqueeze(-1).transpose(-1, -3)

    mask_half = res['mask_half']
    mask_full = torch.cat([mask_half, torch.zeros_like(mask_half)],
                          dim=-1)[..., :-1]

    res_list = []
    for i in range(prediction.shape[0]):
        pred_temp = prediction[i, ...]
        mask_temp = mask_full[i, ...][0, ...].permute(1, 0, 2)
        prediction_conv = pred_temp.permute(2, 0, 1)

        res_inner_list = []

        for j in range(len(forecaster.model.N)):
            res_inner = torch.nn.functional.conv1d(
                prediction_conv[j:j + 1, ...],
                mask_temp[j:j + 1, ...],
                padding=mask_half.shape[-1] - 1)
            res_inner_list.append(res_inner)

        res1 = torch.cat(res_inner_list, dim=0)
        res1 = res1.permute(1, 2, 0)
        res_list.append(res1)
    prediction = torch.stack(res_list, dim=0).numpy().squeeze()
    prediction = np.diff(prediction, axis=1)

    prediction = quantile(torch.tensor(prediction), (0.05, 0.5, 0.95),
                          dim=0).numpy()
    return prediction
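Example #6 spreads SEIR removals over time by convolving each country's predicted series with a delay kernel (`mask_full`) via `torch.nn.functional.conv1d`, then reduces over the sample dimension with `quantile(..., dim=0)`. A minimal sketch of that convolution step, with illustrative shapes rather than the script's:

import torch
import torch.nn.functional as F

series = torch.rand(1, 1, 50)              # (batch, channels, time)
kernel = torch.ones(1, 1, 7) / 7.0         # 7-step delay kernel (here a moving average)
out = F.conv1d(series, kernel, padding=6)  # padding = kernel length - 1, as above
print(out.shape)                           # torch.Size([1, 1, 56])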
Example #7
 def vmin(self) -> float:
     return quantile(self.images.flatten().float(), 0.05).item()
Example #8
for seed in range(25):
    model_id = 'day-{}-rng-{}'.format(days, seed)

    try:
        with open(prefix + 'Loop{}/{}-predictive.pkl'.format(days, model_id),
                  'rb') as f:
            res = pickle.load(f)
    except Exception:
        continue

    with open(prefix + 'Loop{}/{}-forecaster.csv'.format(days, model_id),
              'rb') as f:
        forecaster = pickle.load(f)

    prediction = quantile(res['prediction'].squeeze(), (0.5, ),
                          dim=0).squeeze()
    err = np.diff(prediction, axis=0) - data_dict['daily_death'][1:, ].numpy()
    err_country = np.sum(np.abs(err), axis=0)
    err_country_list.append(err_country)
    seed_list.append(seed)

err = np.stack(err_country_list, axis=0)

best_model = np.argmin(err, axis=0)
best_err = np.min(err, axis=0)
best_seed = [seed_list[x] for x in best_model]

df = pds.DataFrame({
    'country': countries,
    'best_seed': best_seed,
    'best_err': best_err
})
Example #9
        continue

    with open('Loop{}/{}-forecaster.csv'.format(days, model_id), 'wb') as f:
        pickle.dump(forecaster, f)

    samples = forecaster(Y_train,
                         covariates_full_notime,
                         num_samples=n_sample,
                         batch_size=50)

    samples = samples.squeeze()
    init = Y_train[-1, :][None, None, :]
    init = init.repeat(samples.shape[0], 1, 1)
    samples = torch.cat([init, samples], dim=1)
    daily_s = samples[:, 1:, :] - samples[:, :-1, :]
    p10, p50, p90 = quantile(daily_s, (0.1, 0.5, 0.9), dim=0).squeeze(-1)

    rmse = torch.sqrt(
        torch.mean((p50[-days:, :] - Y_daily[-days:, :])**2,
                   dim=0)).squeeze().numpy()
    off = (torch.sum(p50[-days:, :], dim=0) -
           torch.sum(Y_daily[-days:, :], dim=0)).squeeze().numpy()
    for i in zip(countries, rmse, off):
        print(i)
    d = {'countries': countries, 'rmse': rmse, 'total_error': off}
    df = pds.DataFrame(data=d)
    df.to_csv('Loop{}/{}-rmse.csv'.format(days, model_id))

    with open('Loop{}/{}-samples.pkl'.format(days, model_id), 'wb') as f:
        pickle.dump(samples.detach().numpy(), f)
    R0low, R0mid, R0high, map_estimates = model.get_R0(forecaster, Y_train,
Example #10
    predictive_list.append(predictive)

    with open(prefix + 'Loop{}/{}-samples.pkl'.format(days, model_id),
              'rb') as f:
        samples = pickle.load(f)
    samples_list.append(samples)
    seed_list.append(seed)

# validation accuracy
val_window = 14

seir_error_list = []

for i in range(len(predictive_list)):
    seir_train = quantile(predictive_list[i]['prediction'].squeeze(),
                          0.5,
                          dim=0)[-val_window + 1:, :].numpy()
    seir_train = np.diff(seir_train, axis=0)
    seir_label = data_dict['daily_death'][train_len -
                                          val_window:train_len, :].numpy()

    seir_error = np.abs(
        np.sum(seir_train, axis=0) - np.sum(seir_label, axis=0))
    seir_error_list.append(seir_error)

seir_error = np.stack(seir_error_list, axis=0)
best_model = np.argmin(seir_error, axis=0)

best_seed = [seed_list[x] for x in best_model]

test_len = 14
Example #11
def save_stats(model, path, CI=0.95, save_matlab=False):
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", f"{int(100*CI)}% LL", f"{int(100*CI)}% UL"],
    )
    # local parameters
    local_params = [
        "height",
        "width",
        "x",
        "y",
        "background",
    ]

    ci_stats = defaultdict(partial(defaultdict, list))
    num_samples = 10000
    for param in global_params:
        if param == "gain":
            fn = dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            )
        elif param == "pi":
            fn = dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size"))
        elif param == "lamda":
            fn = dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            )
        elif param == "proximity":
            fn = AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (model.data.P + 1) / math.sqrt(12),
            )
        elif param == "trans":
            fn = dist.Dirichlet(
                pyro.param("trans_mean") * pyro.param("trans_size")
            ).to_event(1)
        else:
            raise NotImplementedError
        samples = fn.sample((num_samples,)).data.squeeze()
        ci_stats[param] = {}
        LL, UL = hpdi(
            samples,
            CI,
            dim=0,
        )
        ci_stats[param]["LL"] = LL.cpu()
        ci_stats[param]["UL"] = UL.cpu()
        ci_stats[param]["Mean"] = fn.mean.data.squeeze().cpu()

        # calculate Keq
        if param == "pi":
            ci_stats["Keq"] = {}
            LL, UL = hpdi(samples[:, 1] / (1 - samples[:, 1]), CI, dim=0)
            ci_stats["Keq"]["LL"] = LL.cpu()
            ci_stats["Keq"]["UL"] = UL.cpu()
            ci_stats["Keq"]["Mean"] = (samples[:, 1] / (1 - samples[:, 1])).mean().cpu()

    # this does not need to be very accurate
    num_samples = 1000
    for param in local_params:
        LL, UL, Mean = [], [], []
        for ndx in torch.split(torch.arange(model.data.Nt), model.nbatch_size):
            ndx = ndx[:, None]
            kdx = torch.arange(model.K)[:, None, None]
            ll, ul, mean = [], [], []
            for fdx in torch.split(torch.arange(model.data.F), model.fbatch_size):
                if param == "background":
                    fn = dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx]
                        * Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    )
                elif param == "height":
                    fn = dist.Gamma(
                        Vindex(pyro.param("h_loc"))[kdx, ndx, fdx]
                        * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                        Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                    )
                elif param == "width":
                    fn = AffineBeta(
                        Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                        0.75,
                        2.25,
                    )
                elif param == "x":
                    fn = AffineBeta(
                        Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                elif param == "y":
                    fn = AffineBeta(
                        Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                else:
                    raise NotImplementedError
                samples = fn.sample((num_samples,)).data
                l, u = hpdi(
                    samples,
                    CI,
                    dim=0,
                )
                m = fn.mean.data
                ll.append(l)
                ul.append(u)
                mean.append(m)
            else:
                LL.append(torch.cat(ll, -1))
                UL.append(torch.cat(ul, -1))
                Mean.append(torch.cat(mean, -1))
        else:
            ci_stats[param]["LL"] = torch.cat(LL, -2).cpu()
            ci_stats[param]["UL"] = torch.cat(UL, -2).cpu()
            ci_stats[param]["Mean"] = torch.cat(Mean, -2).cpu()

    for param in global_params:
        if param == "pi":
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"][1].item()
            summary.loc[param, f"{int(100*CI)}% LL"] = ci_stats[param]["LL"][1].item()
            summary.loc[param, f"{int(100*CI)}% UL"] = ci_stats[param]["UL"][1].item()
            # Keq
            summary.loc["Keq", "Mean"] = ci_stats["Keq"]["Mean"].item()
            summary.loc["Keq", f"{int(100*CI)}% LL"] = ci_stats["Keq"]["LL"].item()
            summary.loc["Keq", f"{int(100*CI)}% UL"] = ci_stats["Keq"]["UL"].item()
        elif param == "trans":
            summary.loc["kon", "Mean"] = ci_stats[param]["Mean"][0, 1].item()
            summary.loc["kon", f"{int(100*CI)}% LL"] = ci_stats[param]["LL"][0, 1].item()
            summary.loc["kon", f"{int(100*CI)}% UL"] = ci_stats[param]["UL"][0, 1].item()
            summary.loc["koff", "Mean"] = ci_stats[param]["Mean"][1, 0].item()
            summary.loc["koff", f"{int(100*CI)}% LL"] = ci_stats[param]["LL"][1, 0].item()
            summary.loc["koff", f"{int(100*CI)}% UL"] = ci_stats[param]["UL"][1, 0].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, f"{int(100*CI)}% LL"] = ci_stats[param]["LL"].item()
            summary.loc[param, f"{int(100*CI)}% UL"] = ci_stats[param]["UL"].item()
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    model.params = ci_stats

    # snr
    summary.loc["SNR", "Mean"] = (
        snr(
            model.data.images[:, :, model.cdx],
            ci_stats["width"]["Mean"],
            ci_stats["x"]["Mean"],
            ci_stats["y"]["Mean"],
            model.data.xy[:, :, model.cdx],
            ci_stats["background"]["Mean"],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            model.theta_probs,
        )
        .mean()
        .item()
    )

    # classification statistics
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][: model.data.N, :, model.cdx].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(
            true_labels, pred_labels, zero_division=0
        )
        summary.loc["Precision", "Mean"] = precision_score(
            true_labels, pred_labels, zero_division=0
        )

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        mask = torch.from_numpy(model.data.labels["z"][: model.data.N, :, model.cdx]) > 0
        samples = torch.masked_select(model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", "95% LL"] = z_ll.item()
            summary.loc["p(specific)", "95% UL"] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", "95% LL"] = 0.0
            summary.loc["p(specific)", "95% UL"] = 0.0

    model.summary = summary

    if path is not None:
        path = Path(path)
        torch.save(ci_stats, path / f"{model.full_name}-params.tpqr")
        if save_matlab:
            from scipy.io import savemat

            for param, field in ci_stats.items():
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = field.numpy()
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = value.cpu().numpy()
            savemat(path / f"{model.full_name}-params.mat", ci_stats)
        summary.to_csv(
            path / f"{model.full_name}-summary.csv",
        )
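This example pairs `quantile` with `hpdi` from the same module: `hpdi` returns the narrowest interval containing the requested probability mass (stacked along the first dimension, hence the `LL, UL = ...` unpacking above), while `quantile(samples, 0.5)` supplies the median point estimate. A compact sketch of that summary pattern:

import torch
from pyro.ops.stats import hpdi, quantile

samples = torch.randn(10000).exp()   # skewed (log-normal) posterior draws
ll, ul = hpdi(samples, 0.95, dim=0)  # 95% highest-density interval
median = quantile(samples, 0.5)      # point estimate for the summary row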
Example #12
            with open(prefix + 'AblationLoop{}/{}-predictive.pkl'.format(days, model_id), 'rb') as f:
                predictive = pickle.load(f)
        except Exception:
            continue
        predictive_list.append(predictive)

        with open(prefix + 'AblationLoop{}/{}-samples.pkl'.format(days, model_id), 'rb') as f:
            samples = pickle.load(f)
        samples_list.append(samples)

    val_window = 14

    seir_error_list = []

    for i in range(len(predictive_list)):
        seir_train = quantile(predictive_list[i]['prediction'].squeeze(), 0.5, dim=0)[-val_window + 1:].numpy()
        seir_train = np.diff(seir_train, axis=0)
        seir_label = data_dict['daily_death'][train_len - val_window:train_len, :].numpy()

        seir_error = np.abs(np.sum(seir_train, axis=0) - np.sum(seir_label, axis=0))
        seir_error_list.append(seir_error)

    seir_error = np.stack(seir_error_list, axis=0)
    best_model = np.argmin(seir_error, axis=0)

    test_len = 14
    best_error_list = []

    test_len = test_len - 1
    for j, i in zip(range(len(countries)), best_model):
        c = countries[j]
Example #13
def save_stats(model, path, CI=0.95, save_matlab=False):
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", f"{int(100*CI)}% LL", f"{int(100*CI)}% UL"],
    )

    logger.info("- credible intervals")
    ci_stats = compute_ci(model, model.ci_params, CI)

    for param in global_params:
        if ci_stats[param]["Mean"].ndim == 0:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, "95% LL"] = ci_stats[param]["LL"].item()
            summary.loc[param, "95% UL"] = ci_stats[param]["UL"].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].tolist()
            summary.loc[param, "95% LL"] = ci_stats[param]["LL"].tolist()
            summary.loc[param, "95% UL"] = ci_stats[param]["UL"].tolist()
    logger.info("- spot probabilities")
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()
    ci_stats["p_specific"] = ci_stats["theta_probs"].sum(0)

    # calculate vmin/vmax
    theta_mask = ci_stats["theta_probs"] > 0.5
    if theta_mask.sum():
        hmax = np.percentile(ci_stats["height"]["Mean"][theta_mask], 99)
    else:
        hmax = 1
    ci_stats["height"]["vmin"] = -0.03 * hmax
    ci_stats["height"]["vmax"] = 1.3 * hmax
    ci_stats["width"]["vmin"] = 0.5
    ci_stats["width"]["vmax"] = 2.5
    ci_stats["x"]["vmin"] = -9
    ci_stats["x"]["vmax"] = 9
    ci_stats["y"]["vmin"] = -9
    ci_stats["y"]["vmax"] = 9
    bmax = np.percentile(ci_stats["background"]["Mean"].flatten(), 99)
    ci_stats["background"]["vmin"] = -0.03 * bmax
    ci_stats["background"]["vmax"] = 1.3 * bmax

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    # intensity of target-specific spots
    theta_mask = torch.argmax(ci_stats["theta_probs"], dim=0)
    #  h_specific = Vindex(ci_stats["height"]["Mean"])[
    #      theta_mask, torch.arange(model.data.Nt)[:, None], torch.arange(model.data.F)
    #  ]
    #  ci_stats["h_specific"] = h_specific * (ci_stats["z_map"] > 0).long()

    model.params = ci_stats

    logger.info("- SNR and Chi2-test")
    # snr and chi2 test
    snr = torch.zeros(model.K,
                      model.data.Nt,
                      model.data.F,
                      model.Q,
                      device=torch.device("cpu"))
    chi2 = torch.zeros(model.data.Nt,
                       model.data.F,
                       model.Q,
                       device=torch.device("cpu"))
    for n in range(model.data.Nt):
        snr[:, n], chi2[n] = snr_and_chi2(
            model.data.images[n],
            ci_stats["height"]["Mean"][:, n],
            ci_stats["width"]["Mean"][:, n],
            ci_stats["x"]["Mean"][:, n],
            ci_stats["y"]["Mean"][:, n],
            model.data.xy[n],
            ci_stats["background"]["Mean"][n],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            ci_stats["theta_probs"][:, n],
        )
    snr_masked = snr[ci_stats["theta_probs"] > 0.5]
    summary.loc["SNR", "Mean"] = snr_masked.mean().item()
    ci_stats["chi2"] = {}
    ci_stats["chi2"]["values"] = chi2
    cmax = quantile(ci_stats["chi2"]["values"].flatten(), 0.99)
    ci_stats["chi2"]["vmin"] = -0.03 * cmax
    ci_stats["chi2"]["vmax"] = 1.3 * cmax

    # classification statistics
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][:model.data.N].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC",
                        "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(true_labels,
                                                     pred_labels,
                                                     zero_division=0)
        summary.loc["Precision", "Mean"] = precision_score(true_labels,
                                                           pred_labels,
                                                           zero_division=0)

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        mask = torch.from_numpy(model.data.labels["z"][:model.data.N]) > 0
        samples = torch.masked_select(
            model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", "95% LL"] = z_ll.item()
            summary.loc["p(specific)", "95% UL"] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", "95% LL"] = 0.0
            summary.loc["p(specific)", "95% UL"] = 0.0

    model.summary = summary

    if path is not None:
        path = Path(path)
        param_path = path / f"{model.name}-params.tpqr"
        torch.save(ci_stats, param_path)
        logger.info(f"Parameters were saved in {param_path}")
        if save_matlab:
            from scipy.io import savemat

            for param, field in ci_stats.items():
                if param in (
                        "m_probs",
                        "theta_probs",
                        "z_probs",
                        "z_map",
                        "p_specific",
                        "h_specific",
                        "time1",
                        "ttb",
                ):
                    ci_stats[param] = np.asarray(field)
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = np.asarray(value)
            mat_path = path / f"{model.name}-params.mat"
            savemat(mat_path, ci_stats)
            logger.info(f"Matlab parameters were saved in {mat_path}")
        csv_path = path / f"{model.name}-summary.csv"
        summary.to_csv(csv_path)
        logger.info(f"Summary statistics were saved in {csv_path}")
Example #14
    seed = 42
    np.random.seed(seed)
    pyro.set_rng_seed(seed)
    
    with pyro.plate("respirator", 1000, dim=-2):
        surface_area, layer_params = respirator_model_charge_prior()
        results: torch.Tensor = penetration(surface_area, layer_params)

    plt.style.use('ggplot')
    plt.rcParams['text.usetex'] = True
    
    THRESHOLD = 6e-2  # 6% threshold

    # plt.plot(particle_diam, results)
    low_bound, high_bound = quantile(results, (0.05, 0.95))

    fig: plt.Figure = plt.figure()
    plt.plot(particle_diam, results.mean(dim=0), label="Mean")
    plt.fill_between(particle_diam, low_bound, high_bound, alpha=.4,
                     label=r"90\% confidence interval")
    plt.hlines(THRESHOLD, particle_diam.min(), particle_diam.max(), ls='--',
               label="Penetration threshold")
    plt.xlabel("Particle size $d_p$")
    plt.ylabel("Penetration")
    plt.xscale('log')
    plt.title("Penetration profile. Prior charge density $q \sim \\mathcal{N}(13, 1)$ nm")
    # plt.title("Prior charge density $q \sim \\mathcal{U}(13, 14)$ nm")
    plt.legend()
    plt.tight_layout()
    plt.show()
Example #15
def test_pi():
    x = torch.randn(1000).exp()
    assert_equal(pi(x, prob=0.8), quantile(x, probs=[0.1, 0.9]))
Example #16
    def get_R0(self, forecaster, Y_train, covariates_full, sample_size,
               batch_size):
        n_batch = sample_size // batch_size
        R_list = []
        map_list = []

        for i in range(n_batch):
            with torch.no_grad():
                with pyro.plate("particles", batch_size, dim=-3):
                    map_estimates = forecaster.guide(Y_train[:1, :],
                                                     covariates_full)
                map_list.append(map_estimates)
                data_dim = Y_train.size(-1)
                time_dim = covariates_full.size(-2)

                kernel_lengthscale = map_estimates['r0_lengthscale']
                kernel_var = map_estimates['r0_kernel_var']
                covariates_pyro = pyro_model.helper.reshape_covariates_pyro(
                    covariates_full, self.n_country).to(torch.float)

                kernel_lengthscale = kernel_lengthscale.transpose(-1, -2)
                kernel_var = kernel_var.transpose(-1, -2)

                # covariate
                covariates_pyro = covariates_pyro.transpose(1, 0)
                d_times_t = data_dim * time_dim
                # dxt, 1, p
                covariates_pyro = covariates_pyro.reshape(
                    d_times_t, 1, covariates_pyro.size(-1))

                var = pyro_model.helper.tensor_RBF(kernel_lengthscale,
                                                   kernel_var, covariates_pyro)
                var = var + torch.eye(var.shape[-1]) * 0.01
                A = torch.cholesky(var)

                iid_n = map_estimates['r0_iid_n']
                iid_n = torch.cat([
                    iid_n,
                    torch.randn(iid_n.size(0), iid_n.size(1),
                                time_dim - iid_n.size(-1))
                ], dim=-1)
                iid_n = iid_n.unsqueeze(-1)
                iid_n = iid_n.reshape(iid_n.size(0), 1, d_times_t, 1)

                weight = torch.sigmoid(
                    torch.einsum('abij,abjk->abik', A, iid_n))
                weight = weight[:, 0, ...]  # get rid of batch dimension
                weight = weight.reshape(weight.size(0), data_dim, time_dim, 1)
                R00 = map_estimates['R00']
                R0 = R00 * weight[..., 0]
                R_list.append(R0)

        R0 = torch.cat(R_list, dim=0)
        R0low, R0mid, R0high = quantile(R0, (0.1, 0.5, 0.9), dim=0).squeeze(-1)

        map_estimates = dict.fromkeys(map_list[0].keys())
        for k in map_list[0].keys():
            k_list = []
            for m in map_list:
                k_list.append(m[k])
            kest = torch.cat(k_list, dim=0)
            map_estimates[k] = kest

        return R0low, R0mid, R0high, map_estimates
Example #17
days = 0
train_len = data_dict['cum_death'].shape[0] - days
covariates_actual = pyro_model.helper.get_covariates_intervention(data_dict, train_len, notime=True)
Y_train = pyro_model.helper.get_Y(data_dict, train_len)

seed = 1
model_id = 'day-{}-rng-{}'.format(days, seed)

with open(prefix + 'Loop{}/{}-predictive.pkl'.format(days, model_id), 'rb') as f:
    res = pickle.load(f)

with open(prefix + 'Loop{}/{}-forecaster.csv'.format(days, model_id), 'rb') as f:
    forecaster = pickle.load(f)

prediction = quantile(res['prediction'].squeeze(), (0.5,), dim=0).squeeze()

c = 0

dt_list = data_dict['date_list']
start_date = data_dict['t_init'][c] + pad

plt.plot(dt_list[:-days - 1], np.diff(prediction[:, c]), label='Fitted SEIR')
plt.plot(dt_list, data_dict['daily_death'][:, c], '.', label='actual')

plt.gcf().autofmt_xdate()
plt.title(countries[c])
plt.legend()

R0 = res['R0'].squeeze()
Example #18
def test_pi():
    x = torch.empty(1000).log_normal_(0, 1)
    assert_equal(pi(x, prob=0.8), quantile(x, probs=[0.1, 0.9]))
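Both `test_pi` variants assert the same identity: `pi(x, prob=0.8)`, the central interval containing 80% of the mass, equals the 0.1 and 0.9 quantiles. A quick standalone check, assuming `pi` and `quantile` come from `pyro.ops.stats`:

import torch
from pyro.ops.stats import pi, quantile

x = torch.empty(1000).log_normal_(0, 1)
interval = pi(x, prob=0.8)  # central 80% interval, i.e. the 10th and 90th percentiles
assert torch.allclose(interval, quantile(x, probs=[0.1, 0.9]))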
Example #19
def _quantile(x, dim=0):
    return quantile(x, probs=[0.1, 0.6], dim=dim)
Example #20
 def vmax(self) -> float:
     return quantile(self.images.flatten().float(), 0.99).item()