def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Persist sample dictionaries and a prediction/forecast plot to disk.

    Everything is written under ``./data/seasonal`` (relative to the current
    working directory), which is created if it does not exist:

    * ``piror_samples.npz``, ``posterior_samples.npz`` and
      ``posterior_predictive.npz`` -- the three sample dictionaries, saved
      with one array per key.
    * ``prediction.png`` -- the ground truth overlaid with the predictive
      mean and HPDI band, split into a training part and a forecast part.

    Args:
        x: Observed series; it is flattened with ``ravel()`` for plotting.
        prior_samples: Samples saved as-is to ``piror_samples.npz``.
        posterior_samples: Samples saved as-is to ``posterior_samples.npz``.
        posterior_predictive: Samples saved as-is; must contain the key
            ``"x"``, which is the array that gets plotted.
        num_train: Number of leading time steps (axis 1 of
            ``posterior_predictive["x"]``) treated as the training range;
            the remaining steps are plotted as the forecast.
    """
    root = pathlib.Path("./data/seasonal")
    # parents=True so that a missing ./data directory is created as well
    # instead of raising FileNotFoundError on a fresh checkout.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE(review): "piror" looks like a typo for "prior". The name is kept
    # unchanged because code outside this function may load the file by this
    # exact name -- confirm before renaming.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # The indexing below implies x_pred is 4-D with samples on axis 0 and time
    # on axis 1; presumably the two trailing axes have size 1 -- TODO confirm.
    x_pred = posterior_predictive["x"]

    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    plt.figure(figsize=(12, 6))
    try:
        plt.plot(x.ravel(), label="ground truth", color=colors[0])

        # Index 0 / 1 on the leading axis of the hpdi result are used as the
        # lower / upper bound of the shaded band.
        plt.plot(
            t_train, x_pred_trn.mean(0)[:, 0], label="prediction", color=colors[1]
        )
        plt.fill_between(
            t_train,
            x_hpdi_trn[0, :, 0, 0],
            x_hpdi_trn[1, :, 0, 0],
            alpha=0.3,
            color=colors[1],
        )

        plt.plot(
            t_test, x_pred_tst.mean(0)[:, 0], label="forecast", color=colors[2]
        )
        plt.fill_between(
            t_test,
            x_hpdi_tst[0, :, 0, 0],
            x_hpdi_tst[1, :, 0, 0],
            alpha=0.3,
            color=colors[2],
        )

        # Clamp the y-axis to the observed range (plus a small margin) so the
        # view is set by the data rather than by the width of the bands.
        plt.ylim(x.min() - 0.5, x.max() + 0.5)
        plt.legend()
        plt.tight_layout()
        plt.savefig(root / "prediction.png")
    finally:
        # Close the figure even if plotting or saving fails, so it does not
        # stay registered with pyplot.
        plt.close()
def int_quantize_jit(x: jnp.ndarray, max_int: int, to_type: str):
    """Quantize each row of ``x`` to integer codes in ``[0, max_int]``.

    Every row (axis 1) is affinely mapped so that its minimum becomes 0 and
    its maximum becomes ``max_int``; the result is rounded to the nearest
    integer and cast to ``to_type``.

    Args:
        x: Array to quantize; the min/max reduction runs over axis 1.
        max_int: Largest integer code to produce.
        to_type: Dtype (name) passed to ``astype`` for the quantized codes.

    Returns:
        A tuple ``(offset, scale, quantized)`` where ``offset`` is the
        per-row minimum and ``scale`` the per-row ``max - min`` (both with
        axis 1 kept as size 1), and ``quantized`` holds the integer codes.
        The values are approximately ``offset + scale * quantized / max_int``.
        For a constant row ``scale`` is 0 and all its codes are 0.
    """
    row_min = x.min(axis=1, keepdims=True)
    row_max = x.max(axis=1, keepdims=True)
    offset = row_min
    scale = row_max - row_min
    # A constant row has scale == 0; dividing by it would give 0/0 = NaN and
    # casting NaN to an integer dtype yields unspecified codes. Divide by 1
    # for such rows instead (the numerator is 0 there, so the codes are 0).
    # The returned scale is left untouched.
    safe_scale = jnp.where(scale == 0, 1, scale)
    normalized = (x - row_min) / safe_scale
    # round to nearest instead of round to zero
    return offset, scale, (normalized * max_int + 0.5).astype(to_type)