Пример #1
0
def _plot_mean_with_hpdi(
    t: np.ndarray,
    samples: jnp.ndarray,
    label: str,
    color: str,
) -> None:
    """Draw the sample mean of ``samples`` over ``t`` with a shaded HPDI band.

    Axis 0 of ``samples`` is reduced (mean / ``diagnostics.hpdi``) and axis 1
    is plotted against ``t``. Only component ``(0, 0)`` of the two trailing
    axes is drawn, for both the line and the band, so they always agree.
    """
    hpdi = diagnostics.hpdi(samples)
    plt.plot(t, samples.mean(0)[:, 0, 0], label=label, color=color)
    plt.fill_between(t,
                     hpdi[0, :, 0, 0],
                     hpdi[1, :, 0, 0],
                     alpha=0.3,
                     color=color)


def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Save sampling results and plot the in-sample fit plus the forecast.

    Writes ``piror_samples.npz``, ``posterior_samples.npz``,
    ``posterior_predictive.npz`` and ``prediction.png`` into
    ``./data/seasonal``, creating the directory (and parents) if needed.

    Args:
        x: Observed series; flattened with ``ravel()`` and drawn as the
            ground-truth line against its index.
        prior_samples: Arrays saved verbatim to ``piror_samples.npz``.
        posterior_samples: Arrays saved verbatim to ``posterior_samples.npz``.
        posterior_predictive: Arrays saved verbatim to
            ``posterior_predictive.npz``. Must contain key ``"x"``; axis 0 of
            it is reduced as samples and axis 1 is sliced as time. It is
            indexed with two further trailing axes.
        num_train: Number of leading time steps plotted as "prediction";
            the remaining steps are plotted as "forecast".
    """
    root = pathlib.Path("./data/seasonal")
    # parents=True: without it mkdir raises FileNotFoundError when ./data
    # does not exist yet.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE(review): "piror" is a typo; the filename is kept so existing
    # readers of this file keep working. Rename it together with consumers.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]
    x_pred_trn = x_pred[:, :num_train]
    x_pred_tst = x_pred[:, num_train:]
    t_train = np.arange(num_train)
    t_test = np.arange(num_train, num_train + x_pred_tst.shape[1])

    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    fig = plt.figure(figsize=(12, 6))
    try:
        plt.plot(x.ravel(), label="ground truth", color=colors[0])
        _plot_mean_with_hpdi(t_train, x_pred_trn, "prediction", colors[1])
        _plot_mean_with_hpdi(t_test, x_pred_tst, "forecast", colors[2])

        plt.ylim(x.min() - 0.5, x.max() + 0.5)
        plt.legend()
        plt.tight_layout()
        plt.savefig(root / "prediction.png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
Пример #2
0
def int_quantize_jit(x: jnp.ndarray, max_int: int, to_type: str):
    """Min-max quantize ``x`` row-wise (over axis 1) to integers in [0, max_int].

    Args:
        x: Array with at least two dimensions; min/max are taken over axis 1
            with ``keepdims=True``.
        max_int: Integer that each row's maximum maps to.
        to_type: dtype passed to ``astype`` for the quantized values,
            e.g. ``"uint8"``.

    Returns:
        ``(offset, scale, q)``. ``offset`` is the per-row minimum, ``scale``
        is the per-row range (max - min), both with axis 1 kept, and ``q`` is
        the quantized array. Inverting the mapping,
        ``q / max_int * scale + offset`` approximates ``x``. A constant row
        has ``scale == 0`` and ``q == 0``, so it reconstructs exactly to
        ``offset``.
    """
    row_min = x.min(axis=1, keepdims=True)
    row_max = x.max(axis=1, keepdims=True)

    offset = row_min
    scale = row_max - row_min

    # A constant row has scale == 0, and dividing by it yields 0/0 = NaN,
    # whose integer cast is undefined. Adding the boolean mask divides those
    # rows by 1 instead (giving 0) and leaves every other row untouched.
    safe_scale = scale + (scale == 0)
    normalized = (x - row_min) / safe_scale

    # Values are >= 0 here, so +0.5 followed by the truncating cast rounds
    # to the nearest integer instead of toward zero.
    return offset, scale, (normalized * max_int + 0.5).astype(to_type)