def save(self, fp):
    """Serialize every registered parameter to *fp* as an ``.npz`` archive.

    Args:
        fp: file path or file-like object accepted by ``jnp.savez``.

    Raises:
        ValueError: if no parameters have been registered on this object.
    """
    params = self._flatten_params()
    if not params:
        raise ValueError('Trying to save but no parameters were registered')
    jnp.savez(fp, **params)
def save_samples(filename, prior_samples, mcmc_samples, post_pred_samples, forecast_samples):
    """Bundle the four sample groups into one ``.npz`` archive at *filename*.

    Each argument is stored under its own name so it can be recovered with
    ``np.load(filename)["prior_samples"]`` etc.
    """
    groups = {
        'prior_samples': prior_samples,
        'mcmc_samples': mcmc_samples,
        'post_pred_samples': post_pred_samples,
        'forecast_samples': forecast_samples,
    }
    np.savez(filename, **groups)
def _save_results(
    y: np.ndarray,
    mcmc: infer.MCMC,
    prior: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    *,
    var_names: Optional[List[str]] = None,
) -> None:
    """Persist sampling results and write diagnostic figures for the
    Boston-PCA regression run.

    Writes under ``./data/boston_pca_reg``: the posterior samples and the
    posterior predictive as ``.npz`` archives, plus ``trace.png``,
    ``ppc.png`` and ``prediction.png``.

    Args:
        y: observed targets (1-D along the plotted axis).
        mcmc: the fitted numpyro MCMC object (consumed by ArviZ).
        prior: prior samples keyed by site name.
        posterior_samples: posterior samples keyed by site name.
        posterior_predictive: predictive samples; must contain key ``"y"``.
        var_names: optional subset of sites to show in the trace plot.
    """
    root = pathlib.Path("./data/boston_pca_reg")
    root.mkdir(exist_ok=True)
    # NOTE(review): jnp.savez availability depends on the jax version in use;
    # np.savez also accepts jax arrays — confirm against the pinned jax.
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # Arviz: wrap the numpyro results for trace / posterior-predictive plots.
    numpyro_data = az.from_numpyro(
        mcmc,
        prior=prior,
        posterior_predictive=posterior_predictive,
    )
    az.plot_trace(numpyro_data, var_names=var_names)
    plt.savefig(root / "trace.png")
    plt.close()

    az.plot_ppc(numpyro_data)
    plt.legend(loc="upper right")
    plt.savefig(root / "ppc.png")
    plt.close()

    # Prediction: posterior-predictive mean with an HPDI band over all of y.
    y_pred = posterior_predictive["y"]
    y_hpdi = diagnostics.hpdi(y_pred)
    # 80/20 split marker — presumably matches the train/test split used
    # upstream; TODO confirm with the caller.
    train_len = int(len(y) * 0.8)
    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    plt.figure(figsize=(12, 6))
    plt.plot(y, color=colors[0])
    plt.plot(y_pred.mean(axis=0), color=colors[1])
    plt.fill_between(np.arange(len(y)), y_hpdi[0], y_hpdi[1],
                     color=colors[1], alpha=0.3)
    # Vertical line separating the training region from the held-out region.
    plt.axvline(train_len, linestyle="--", color=colors[2])
    plt.xlabel("Index [a.u.]")
    plt.ylabel("Target [a.u.]")
    plt.savefig(root / "prediction.png")
    plt.close()
def get_estimate(self, data, kwargs, fit_kwargs=None, state=False,
                 validate=True, fit=False, set=True):
    """Regression test for ``imnn.get_estimate``.

    Builds an IMNN from *kwargs*, optionally fits it (or just sets its
    Fisher statistics), computes an estimate for *data* and compares it
    against a cached target stored in ``test/<filename>.npz``.

    Args:
        data: input passed to ``imnn.get_estimate``.
        kwargs: constructor kwargs for ``self.imnn`` (deep-copied).
        fit_kwargs: kwargs for ``imnn.fit``; must contain "λ" and "ϵ".
        state / validate: forwarded to ``self.preload``.
        fit: if True run a full fit, otherwise only set the F statistics.
        set: if False, only assert that calling get_estimate before the
            statistics are set raises the expected ValueError, then return.
    """
    _kwargs = copy.deepcopy(kwargs)
    _kwargs = self.preload(_kwargs, state=state, validate=validate)
    imnn = self.imnn(**_kwargs)
    if not set:
        with pytest.raises(ValueError) as info:
            imnn.get_estimate(data)
        # NOTE(review): this text (including the "inital_w" spelling) must
        # match the library's actual error string exactly — do not "fix" it.
        assert info.match(
            re.escape(
                "Fisher information has not yet been calculated. Please "
                + "run `imnn.set_F_statistics({w}, {key}, {validate}) "
                + "with `w = imnn.final_w`, `w = imnn.best_w`, "
                + "`w = imnn.inital_w` or otherwise, `validate = True` "
                + "should be set if not simulating on the fly."))
        return
    _fit_kwargs = copy.deepcopy(fit_kwargs)
    λ = _fit_kwargs.pop("λ")
    ϵ = _fit_kwargs.pop("ϵ")
    if fit:
        imnn.fit(λ, ϵ, **_fit_kwargs)
    else:
        imnn.set_F_statistics(key=self.stats_key)
    estimate = imnn.get_estimate(data)
    name = self.set_name(state, validate, fit, _kwargs["n_d"])
    if self.save:
        files = {f"{name}estimate": estimate}
        try:
            # Merge with existing targets: keep previously saved entries and
            # only add keys that are not in the archive yet.
            targets = np.load(f"test/{self.filename}.npz")
            files = {k: files[k] for k in files.keys() - targets.keys()}
            np.savez(f"test/{self.filename}.npz", **{**files, **targets})
        except Exception:
            # No (readable) archive yet — create it from scratch.
            np.savez(f"test/{self.filename}.npz", **files)
    targets = np.load(f"test/{self.filename}.npz")
    assert np.all(np.equal(estimate, targets[f"{name}estimate"]))
def save_test_embeddings(key, experiment, save_path, n_samples_per_batch=64):
    """Encode the test+validation split, UMAP-project the latents, and save
    latents ``z``, labels ``y`` and the 2-D projection ``u`` to *save_path*."""
    exp, sampler, encoder, decoder = experiment
    n_train, n_test, n_validation = exp.split_shapes
    loader = exp.data_loader
    # Fetch the full test + validation portion of the dataset in one call.
    inputs, labels = loader((n_test + n_validation, ), start=0, split='tpv',
                            return_labels=True, onehot=False)
    # Batched encoding keeps memory bounded.
    latents, _ = batched_evaluate(key, encoder, inputs, n_samples_per_batch)
    # Supervised UMAP projection of the latent codes (fixed seed).
    projection = umap.UMAP(random_state=0).fit_transform(latents, y=labels)
    np.savez(save_path, z=np.array(latents), y=np.array(labels), u=projection)
def main(m, seed):
    """Estimate beliefs for model order *m*, run NUTS on trials
    [100, 800), compute WAIC and save samples + WAIC to disk.

    Args:
        m: model-order index; m <= 10 selects nu_max=m, otherwise
           nu_max=10 with nu_min=m-10.
        seed: PRNG seed for jax.
    """
    rng_key = random.PRNGKey(seed)
    # Trial window used for fitting.
    cutoff_up = 800
    cutoff_down = 100
    if m <= 10:
        seq, _ = estimate_beliefs(outcomes_data, responses_data,
                                  mask=mask_data, nu_max=m)
    else:
        seq, _ = estimate_beliefs(outcomes_data, responses_data,
                                  mask=mask_data, nu_max=10, nu_min=m-10)
    model = generative_model
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
    # Restrict the belief sequences to the fitting window.
    seq = (seq['beliefs'][0][cutoff_down:cutoff_up],
           seq['beliefs'][1][cutoff_down:cutoff_up])
    rng_key, _rng_key = random.split(rng_key)
    mcmc.run(
        _rng_key,
        seq,
        y=responses_data[cutoff_down:cutoff_up],
        mask=mask_data[cutoff_down:cutoff_up].astype(bool),
        extra_fields=('potential_energy',)
    )
    samples = mcmc.get_samples()
    waic = log_pred_density(
        model, samples, seq,
        y=responses_data[cutoff_down:cutoff_up],
        mask=mask_data[cutoff_down:cutoff_up].astype(bool)
    )['waic']
    # NOTE(review): jnp.savez availability depends on the jax version;
    # np.savez also accepts jax arrays — confirm against the pinned jax.
    jnp.savez('fit_waic_sample/dyn_fit_waic_sample_minf{}.npz'.format(m),
              samples=samples, waic=waic)
    print(mcmc.get_extra_fields()['potential_energy'].mean())
def main():
    """Download optimized circuits matching the CLI filters, compute the
    Hessian eigen-spectrum of each circuit's loss, save the arrays and plot
    spectrum/histogram figures under ``results_hessian``."""
    args = get_arguments()
    project = args.project
    # Fixed Ising couplings g=2, h=0; qubit count from the CLI.
    n_qubits, g, h = args.n_qubits, 2, 0
    filters = dict(n_qubits=n_qubits, g=g, h=h)
    if args.n_layers_list is not None:
        print(f'Add [n_layers] filter: {args.n_layers_list}')
        # Mongo-style "$in" filter — presumably consumed by the circuit
        # download backend; verify against download_circuits.
        filters['n_layers'] = {'$in': args.n_layers_list}
    if args.lr is not None:
        print(f'Add [lr] filter: {args.lr}')
        filters['lr'] = args.lr
    resdir = Path('results_hessian')
    opt_circuits = download_circuits(project, filters, resdir)
    ham_matrix = qnnops.ising_hamiltonian(n_qubits, g, h)
    print('Hessian spectrum')
    # Cache one hessian function per circuit topology so jax only traces
    # each (n_layers, rot_axis) combination once.
    hess_fns = {}
    for cfg, fpath in opt_circuits:
        print(f'| computing hessian of {fpath}')
        params = jnp.load(fpath)
        circuit_name = f'Q{n_qubits}-L{cfg["n_layers"]}-R{cfg["rot_axis"]}'
        if circuit_name not in hess_fns:
            loss_fn = get_loss_fn(
                ham_matrix, n_qubits, cfg['n_layers'], cfg['rot_axis'])
            hess_fns[circuit_name] = jax.hessian(loss_fn)
        _, hess_eigvals = compute_hessian_eigenvalues(
            hess_fns[circuit_name], params)
        name = get_normalized_name(cfg)
        # NOTE(review): jnp.savez availability depends on the jax version;
        # np.savez also accepts jax arrays — confirm against the pinned jax.
        jnp.savez(
            resdir / f'{name}_all.npz',
            params=params,
            ham_matrix=ham_matrix,
            hess_spectrum=hess_eigvals,
        )
        print(f'| ...plotting hessian spectrum and histogram')
        plot_spectrum(hess_eigvals, resdir / f'{name}_spectrum.pdf')
        plot_histogram(hess_eigvals, resdir / f'{name}_histogram.pdf')
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Save sampling results for the seasonal model and plot the fit.

    Writes prior/posterior/predictive ``.npz`` archives and a
    ``prediction.png`` figure (train-window fit + forecast with HPDI
    bands) under ``./data/seasonal``.

    Args:
        x: observed series; flattened for plotting.
        prior_samples: prior samples keyed by site name.
        posterior_samples: posterior samples keyed by site name.
        posterior_predictive: predictive samples; must contain key ``"x"``
            with draws along axis 0 and time along axis 1.
        num_train: number of leading time steps that were used for training.
    """
    root = pathlib.Path("./data/seasonal")
    root.mkdir(exist_ok=True)
    # Bug fix: filename was misspelled "piror_samples.npz"; other scripts in
    # this file save the same artifact as "prior_samples.npz".
    jnp.savez(root / "prior_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # Split predictive draws into the training window and the forecast tail.
    x_pred = posterior_predictive["x"]
    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)
    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])
    plt.plot(t_train, x_pred_trn.mean(0)[:, 0],
             label="prediction", color=colors[1])
    plt.fill_between(t_train, x_hpdi_trn[0, :, 0, 0], x_hpdi_trn[1, :, 0, 0],
                     alpha=0.3, color=colors[1])
    plt.plot(t_test, x_pred_tst.mean(0)[:, 0],
             label="forecast", color=colors[2])
    plt.fill_between(t_test, x_hpdi_tst[0, :, 0, 0], x_hpdi_tst[1, :, 0, 0],
                     alpha=0.3, color=colors[2])
    plt.ylim(x.min() - 0.5, x.max() + 0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()
def main() -> None:
    """Fit the SGT model to the loaded series with NUTS, forecast the
    held-out tail, save the samples and plot the results."""
    df = load_dataset()
    test_index = 80
    test_len = len(df) - test_index
    # NOTE(review): .loc[:test_index] is label-based and *inclusive*, so the
    # training slice contains row `test_index` as well — confirm this overlap
    # with the test window is intended.
    y_train = jnp.array(df.loc[:test_index, "value"], dtype=jnp.float32)

    # Inference
    kernel = NUTS(sgt)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
    mcmc.run(random.PRNGKey(0), y_train, seasonality=38)
    mcmc.print_summary()
    posterior_samples = mcmc.get_samples()

    # Prediction: forecast `test_len` future steps from the posterior.
    predictive = Predictive(sgt, posterior_samples,
                            return_sites=["y_forecast"])
    posterior_predictive = predictive(random.PRNGKey(1), y_train,
                                      seasonality=38, future=test_len)

    root = pathlib.Path("./data/time_series")
    root.mkdir(exist_ok=True)
    # NOTE(review): jnp.savez availability depends on the jax version;
    # np.savez also accepts jax arrays — confirm against the pinned jax.
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    plot_results(
        df["time"].values,
        df["value"].values,
        posterior_samples,
        posterior_predictive,
        test_index,
        root,
    )
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
) -> None:
    """Save sampling results for the Kalman model and plot fit + forecast.

    Writes prior/posterior/predictive ``.npz`` archives and ``kalman.png``
    under ``./data/kalman``. The observed series *x* defines the training
    window; predictive draws beyond ``x.shape[0]`` are plotted as forecast.

    Args:
        x: observed series of length ``len_train``.
        prior_samples: prior samples keyed by site name.
        posterior_samples: posterior samples keyed by site name.
        posterior_predictive: predictive samples; must contain key ``"x"``
            with draws along axis 0 and time along axis 1.
    """
    root = pathlib.Path("./data/kalman")
    root.mkdir(exist_ok=True)
    # Bug fix: filename was misspelled "piror_samples.npz"; other scripts in
    # this file save the same artifact as "prior_samples.npz".
    jnp.savez(root / "prior_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # Split predictive draws into the observed window and the forecast tail.
    len_train = x.shape[0]
    x_pred_trn = posterior_predictive["x"][:, :len_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    x_pred_tst = posterior_predictive["x"][:, len_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    len_test = x_pred_tst.shape[1]

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])
    t_train = np.arange(len_train)
    plt.plot(t_train, x_pred_trn.mean(0).ravel(),
             label="prediction", color=colors[1])
    plt.fill_between(t_train, x_hpdi_trn[0].ravel(), x_hpdi_trn[1].ravel(),
                     alpha=0.3, color=colors[1])
    t_test = np.arange(len_train, len_train + len_test)
    plt.plot(t_test, x_pred_tst.mean(0).ravel(),
             label="forecast", color=colors[2])
    plt.fill_between(t_test, x_hpdi_tst[0].ravel(), x_hpdi_tst[1].ravel(),
                     alpha=0.3, color=colors[2])
    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "kalman.png")
    plt.close()
def main(args: argparse.Namespace) -> None:
    """Run the selected HMM variant on the JSB chorales: draw prior samples,
    fit with NUTS, draw posterior-predictive samples and save all three
    archives under ``./data/hmm_enum``."""
    model = model_dict[args.model]

    _, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
    lengths, sequences = fetch()
    # Remove never used data dimension to reduce computation time
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    sequences = sequences[..., present_notes]
    batch, seq_len, data_dim = sequences.shape

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_prior, rng_key_pred = random.split(rng_key, 3)

    # Prior samples: run the model without data, forecasting 20 steps ahead.
    predictive = infer.Predictive(model, num_samples=10)
    prior_samples = predictive(rng_key_prior, batch=batch, seq_len=seq_len,
                               data_dim=data_dim, future_steps=20)

    kernel = infer.NUTS(model)
    mcmc = infer.MCMC(kernel, args.num_warmup, args.num_samples,
                      args.num_chains)
    mcmc.run(rng_key, sequences, lengths)
    posterior_samples = mcmc.get_samples()

    # Posterior predictive, forecasting 10 steps ahead.
    predictive = infer.Predictive(model, posterior_samples)
    predictive_samples = predictive(rng_key_pred, sequences, lengths,
                                    future_steps=10)

    path = pathlib.Path("./data/hmm_enum")
    path.mkdir(exist_ok=True)
    # NOTE(review): jnp.savez availability depends on the jax version;
    # np.savez also accepts jax arrays — confirm against the pinned jax.
    jnp.savez(path / "prior_samples.npz", **prior_samples)
    jnp.savez(path / "posterior_samples.npz", **posterior_samples)
    jnp.savez(path / "predictive_samples.npz", **predictive_samples)
def save_params(path: Union[str, pathlib.Path],
                params: Dict[str, jnp.ndarray]) -> None:
    """Save a mapping of named arrays to *path* as an ``.npz`` archive."""
    jnp.savez(path, **params)
def main(m, seed, device, dynamic_gamma, dynamic_preference):
    """Fit the single-agent model of order *m* to trials [400, 1000) for all
    participants in parallel (vmap over the last data axis) and save the
    posterior samples, WAIC and log-likelihood to ``fit_data/``."""
    import jax.numpy as jnp
    from numpyro.infer import MCMC, NUTS
    from jax import random, nn, vmap, jit, devices, device_put
    # import utility functions for model inversion and belief estimation
    from utils import estimate_beliefs, single_model, log_pred_density
    # import data loader
    from stats import load_data
    outcomes_data, responses_data, mask_data, ns, _, _ = load_data()
    mask_data = jnp.array(mask_data)
    responses_data = jnp.array(responses_data).astype(jnp.int32)
    outcomes_data = jnp.array(outcomes_data).astype(jnp.int32)
    rng_key = random.PRNGKey(seed)
    # Trial window used for fitting.
    cutoff_up = 1000
    cutoff_down = 400
    print(m, seed, dynamic_gamma, dynamic_preference)
    # m <= 10 selects nu_max=m; otherwise nu_max=10 with nu_min=m-10.
    if m <= 10:
        seq, _ = estimate_beliefs(outcomes_data, responses_data, device,
                                  mask=mask_data, nu_max=m)
    else:
        seq, _ = estimate_beliefs(outcomes_data, responses_data, device,
                                  mask=mask_data, nu_max=10, nu_min=m-10)
    model = lambda *args: single_model(
        *args, dynamic_gamma=dynamic_gamma,
        dynamic_preference=dynamic_preference)
    # init preferences: outcome counts over the pre-fit window [0, 400).
    c0 = jnp.sum(nn.one_hot(outcomes_data[:cutoff_down], 4)
                 * jnp.expand_dims(mask_data[:cutoff_down], -1), 0)
    # The dynamic-gamma variant is harder to sample, so fewer/longer chains.
    if dynamic_gamma:
        num_warmup = 500
        num_samples = 500
        num_chains = 2
    else:
        num_warmup = 100
        num_samples = 100
        num_chains = 10

    def inference(belief_sequences, obs, mask, rng_key):
        # One NUTS run per participant; vectorized chains, no progress bar
        # because this function is vmapped/jitted below.
        nuts_kernel = NUTS(model, dense_mass=True)
        mcmc = MCMC(nuts_kernel, num_warmup=num_warmup,
                    num_samples=num_samples, num_chains=num_chains,
                    chain_method="vectorized", progress_bar=False)
        mcmc.run(
            rng_key,
            belief_sequences,
            obs,
            mask,
            extra_fields=('potential_energy',)
        )
        samples = mcmc.get_samples()
        potential_energy = mcmc.get_extra_fields()['potential_energy'].mean()
        # mcmc.print_summary()
        return samples, potential_energy

    seq = (
        seq['beliefs'][0][cutoff_down:cutoff_up],
        seq['beliefs'][1][cutoff_down:cutoff_up],
        outcomes_data[cutoff_down:cutoff_up],
        c0
    )
    y = responses_data[cutoff_down:cutoff_up]
    mask = mask_data[cutoff_down:cutoff_up].astype(bool)
    # Presumably the trailing axis indexes participants — one rng key and one
    # vmapped inference run per participant; TODO confirm data layout.
    n = mask.shape[-1]
    rng_keys = random.split(rng_key, n)
    samples, potential_energy = jit(
        vmap(inference, in_axes=((1, 1, 1, 0), 1, 1, 0)))(
        seq, y, mask, rng_keys)
    waic = vmap(lambda *args: log_pred_density(model, *args),
                in_axes=(0, (1, 1, 1, 0), 1, 1))
    waic, log_likelihood = waic(samples, seq, y, mask)
    print('waic', waic.mean())
    # NOTE(review): jnp.savez availability depends on the jax version;
    # np.savez also accepts jax arrays — confirm against the pinned jax.
    jnp.savez('fit_data/fit_waic_sample_minf{}_gamma{}_pref{}_long.npz'.format(
        m, int(dynamic_gamma), int(dynamic_preference)),
        samples=samples, waic=waic, log_likelihood=log_likelihood)
def save_history(filename, history, upload_to_wandb=True):
    """Write *history* (a mapping of named jax arrays) to the results
    directory as an ``.npz`` archive and optionally mirror it to wandb."""
    target = str(get_result_path(filename))
    jnp.savez(target, **history)
    if not upload_to_wandb:
        return
    safe_wandb_save(target)
def combined_running_test(self, data, kwargs, fit_kwargs, state=False,
                          validate=True, fit=False, none_first=True,
                          implemented=True, aggregated=False):
    """End-to-end IMNN test: pre-fit error, fit / set-statistics, estimate
    regression against cached targets, plotting, and the print_rate-switch
    error after fitting.

    Args:
        data: iterable of inputs, each passed to ``imnn.get_estimate``.
        kwargs: constructor kwargs for ``self.imnn`` (deep-copied).
        fit_kwargs: kwargs for ``imnn.fit``; must contain "λ" and "ϵ".
        state / validate: forwarded to ``self.preload``.
        fit: if True fit the network, otherwise only set F statistics.
        none_first: fit with ``print_rate=None`` first (then flip to 100),
            or the other way round; selects which switch-error is expected.
        implemented: whether ``get_summaries`` is implemented for this
            variant; if not, the matching ValueError is asserted instead.
        aggregated: skips the final print_rate-switch check when True.
    """
    _kwargs = copy.deepcopy(kwargs)
    _kwargs = self.preload(_kwargs, state=state, validate=validate)
    imnn = self.imnn(**_kwargs)
    # Estimating before any statistics are set must fail.
    with pytest.raises(ValueError) as info:
        imnn.get_estimate(data)
    # NOTE(review): this text (including the "inital_w" spelling) must match
    # the library's actual error string exactly — do not "fix" it here.
    assert info.match(
        re.escape(
            "Fisher information has not yet been calculated. Please "
            + "run `imnn.set_F_statistics({w}, {key}, {validate}) "
            + "with `w = imnn.final_w`, `w = imnn.best_w`, "
            + "`w = imnn.inital_w` or otherwise, `validate = True` "
            + "should be set if not simulating on the fly."))
    # Timestamp for unique figure filenames. NOTE(review): this local name
    # shadows any module-level `time` import within this function.
    time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    name = self.set_name(state, validate, fit, _kwargs["n_d"])
    if fit:
        _fit_kwargs = copy.deepcopy(fit_kwargs)
        λ = _fit_kwargs.pop("λ")
        ϵ = _fit_kwargs.pop("ϵ")
        if not implemented:
            with pytest.raises(ValueError) as info:
                imnn.fit(λ, ϵ, **_fit_kwargs)
            assert info.match("`get_summaries` not implemented")
        else:
            # Fit once with one print_rate mode, remember the error message
            # expected when later switching to the other mode.
            if none_first:
                _fit_kwargs["print_rate"] = None
                string = ("Cannot run IMNN with progress bar after "
                          + "running without progress bar. Either set "
                          + "`print_rate` to None or reinitialise the IMNN.")
            else:
                _fit_kwargs["print_rate"] = 100
                string = ("Cannot run IMNN without progress bar after "
                          + "running with progress bar. Either set "
                          + "`print_rate` to an int or reinitialise the "
                          + "IMNN.")
            imnn.fit(λ, ϵ, **_fit_kwargs)
            # Flip the print_rate mode for the final switch-error check.
            if none_first:
                _fit_kwargs["print_rate"] = 100
            else:
                _fit_kwargs["print_rate"] = None
    else:
        if not implemented:
            with pytest.raises(ValueError) as info:
                imnn.set_F_statistics(key=self.stats_key)
            assert info.match("`get_summaries` not implemented")
        else:
            imnn.set_F_statistics(key=self.stats_key)
    if implemented:
        estimates = []
        for element in data:
            estimates.append(imnn.get_estimate(element))
        if self.save:
            files = {
                f"{name}estimates": estimates,
                f"{name}F": imnn.F,
                f"{name}C": imnn.C,
                f"{name}invC": imnn.invC,
                f"{name}dμ_dθ": imnn.dμ_dθ,
                f"{name}μ": imnn.μ,
                f"{name}invF": imnn.invF
            }
            try:
                # Merge with existing targets: never overwrite cached keys,
                # only add keys missing from the archive.
                targets = np.load(f"test/{self.filename}.npz")
                files = {
                    k: files[k] for k in files.keys() - targets.keys()
                }
                np.savez(f"test/{self.filename}.npz", **{
                    **files, **targets
                }, allow_pickle=True)
            except Exception:
                # No (readable) archive yet — create it from scratch.
                np.savez(f"test/{self.filename}.npz", **files,
                         allow_pickle=True)
        targets = np.load(f"test/{self.filename}.npz", allow_pickle=True)
        for i, estimate in enumerate(estimates):
            assert np.all(
                np.equal(estimate, targets[f"{name}estimates"][i]))
        assert np.all(np.equal(imnn.F, targets[f"{name}F"]))
        assert np.all(np.equal(imnn.C, targets[f"{name}C"]))
        assert np.all(np.equal(imnn.invC, targets[f"{name}invC"]))
        assert np.all(np.equal(imnn.dμ_dθ, targets[f"{name}dμ_dθ"]))
        assert np.all(np.equal(imnn.μ, targets[f"{name}μ"]))
        assert np.all(np.equal(imnn.invF, targets[f"{name}invF"]))
        # Plotting must produce a file on disk.
        imnn.plot(expected_detF=50,
                  filename=f"test/figures/{self.filename}/{name}_{time}.pdf")
        plt.close("all")
        assert pathlib.Path(
            f"test/figures/{self.filename}/{name}_{time}.pdf").is_file()
    if fit and implemented and (not aggregated):
        # Refitting the same instance with the opposite print_rate mode must
        # raise the switch error prepared above.
        with pytest.raises(ValueError) as info:
            imnn.fit(λ, ϵ, **_fit_kwargs)
        assert info.match(string)
def fit(self, kwargs, fit_kwargs, state=False, validate=True, set=True,
        none_first=True):
    """Test harness for ``imnn.fit``: fit once, compare the resulting
    statistics against cached targets, then assert that refitting the same
    instance with the opposite ``print_rate`` mode raises the switch error.

    Args:
        kwargs: constructor kwargs for ``self.imnn`` (deep-copied).
        fit_kwargs: kwargs for ``imnn.fit``; must contain "λ" and "ϵ".
        state / validate: forwarded to ``self.preload``.
        set: if False, only assert fit raises "`get_summaries` not
            implemented" and return.
        none_first: fit with ``print_rate=None`` first (then flip to 100),
            or the other way round.
    """
    _kwargs = copy.deepcopy(kwargs)
    _kwargs = self.preload(_kwargs, state=state, validate=validate)
    _fit_kwargs = copy.deepcopy(fit_kwargs)
    λ = _fit_kwargs.pop("λ")
    ϵ = _fit_kwargs.pop("ϵ")
    imnn = self.imnn(**_kwargs)
    if not set:
        with pytest.raises(ValueError) as info:
            self.imnn(**_kwargs).fit(λ, ϵ, **_fit_kwargs)
        assert info.match("`get_summaries` not implemented")
        return
    # Fit once with one print_rate mode, remembering the error expected
    # when later switching to the other mode on the same instance.
    if none_first:
        _fit_kwargs["print_rate"] = None
        string = ("Cannot run IMNN with progress bar after running "
                  + "without progress bar. Either set `print_rate` to "
                  + "None or reinitialise the IMNN.")
    else:
        _fit_kwargs["print_rate"] = 100
        string = ("Cannot run IMNN without progress bar after running "
                  + "with progress bar. Either set `print_rate` to an int "
                  + "or reinitialise the IMNN.")
    imnn.fit(λ, ϵ, **_fit_kwargs)
    name = self.set_name(state, validate, False, _kwargs["n_d"])
    if self.save:
        files = {
            f"{name}F": imnn.F,
            f"{name}C": imnn.C,
            f"{name}invC": imnn.invC,
            f"{name}dμ_dθ": imnn.dμ_dθ,
            f"{name}μ": imnn.μ,
            f"{name}invF": imnn.invF,
        }
        try:
            # Merge with existing targets: cached keys win, new keys added.
            targets = np.load(f"test/{self.filename}.npz")
            files = {k: files[k] for k in files.keys() - targets.keys()}
            np.savez(f"test/{self.filename}.npz", **{**files, **targets})
        except Exception:
            # No (readable) archive yet — create it from scratch.
            np.savez(f"test/{self.filename}.npz", **files)
    targets = np.load(f"test/{self.filename}.npz")
    assert np.all(np.equal(imnn.F, targets[f"{name}F"]))
    assert np.all(np.equal(imnn.C, targets[f"{name}C"]))
    assert np.all(np.equal(imnn.invC, targets[f"{name}invC"]))
    assert np.all(np.equal(imnn.dμ_dθ, targets[f"{name}dμ_dθ"]))
    assert np.all(np.equal(imnn.μ, targets[f"{name}μ"]))
    assert np.all(np.equal(imnn.invF, targets[f"{name}invF"]))
    # Flip the print_rate mode and refit.
    if none_first:
        _fit_kwargs["print_rate"] = 100
    else:
        _fit_kwargs["print_rate"] = None
    with pytest.raises(ValueError) as info:
        # Bug fix: was `imnn(**_kwargs).fit(...)`, which *calls the already
        # constructed instance* (a TypeError) rather than refitting it. The
        # switch error can only come from refitting the instance that has
        # already run, so call fit on that same instance.
        imnn.fit(λ, ϵ, **_fit_kwargs)
    assert info.match(string)
    return
def main(seed, device, dynamic_gamma, dynamic_preference, mc_type):
    """Fit the mixture model over a range of belief model-orders M to trials
    [400, 1000) for all participants in parallel and save the posterior
    samples to ``fit_data/``.

    Args:
        seed: PRNG seed for jax.
        device: jax device name passed to ``estimate_beliefs``/``devices``.
        dynamic_gamma / dynamic_preference: mixture_model configuration.
        mc_type: 'nu_max' compares M in 1..10 (regular condition); anything
            else compares M in {1, 11..19} (irregular condition).
    """
    import jax.numpy as jnp
    from numpyro.infer import MCMC, NUTS
    from jax import random, nn, devices, device_put, vmap, jit
    # import utility functions for model inversion and belief estimation
    from utils import estimate_beliefs, mixture_model
    # import data loader
    from stats import load_data
    outcomes_data, responses_data, mask_data, ns, _, _ = load_data()
    print(seed, device, dynamic_gamma, dynamic_preference)
    model = lambda *args: mixture_model(
        *args, dynamic_gamma=dynamic_gamma,
        dynamic_preference=dynamic_preference)
    m_data = jnp.array(mask_data)
    r_data = jnp.array(responses_data).astype(jnp.int32)
    o_data = jnp.array(outcomes_data).astype(jnp.int32)
    rng_key = random.PRNGKey(seed)
    # Trial window used for fitting.
    cutoff_up = 1000
    cutoff_down = 400
    priors = []
    params = []
    if mc_type == 'nu_max':
        M_rng = list(range(1, 11))  # model comparison for regular condition
    else:
        # model comparison for irregular condition
        M_rng = [1,] + list(range(11, 20))
    # Collect belief sequences for every candidate model order M.
    for M in M_rng:
        if M <= 10:
            seq, _ = estimate_beliefs(o_data, r_data, device, mask=m_data,
                                      nu_max=M)
        else:
            seq, _ = estimate_beliefs(o_data, r_data, device, mask=m_data,
                                      nu_max=10, nu_min=M-10)
        priors.append(seq['beliefs'][0][cutoff_down:cutoff_up])
        params.append(seq['beliefs'][1][cutoff_down:cutoff_up])
    device = devices(device)[0]
    # init preferences: outcome counts over the pre-fit window [0, 400).
    # NOTE(review): uses the raw numpy `outcomes_data`/`mask_data` here while
    # the jnp copies are used elsewhere — confirm this is intentional.
    c0 = jnp.sum(nn.one_hot(outcomes_data[:cutoff_down], 4)
                 * jnp.expand_dims(mask_data[:cutoff_down], -1), 0)
    # The dynamic-gamma variant is harder to sample, so one longer chain.
    if dynamic_gamma:
        num_warmup = 1000
        num_samples = 1000
        num_chains = 1
    else:
        num_warmup = 200
        num_samples = 200
        num_chains = 5

    def inference(belief_sequences, obs, mask, rng_key):
        # One NUTS run per participant; vectorized chains, no progress bar
        # because this function is vmapped/jitted below.
        nuts_kernel = NUTS(model, dense_mass=True)
        mcmc = MCMC(
            nuts_kernel,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False
        )
        mcmc.run(
            rng_key,
            belief_sequences,
            obs,
            mask,
            extra_fields=('potential_energy',)
        )
        samples = mcmc.get_samples()
        potential_energy = mcmc.get_extra_fields()['potential_energy'].mean()
        # mcmc.print_summary()
        return samples, potential_energy

    # Stack the per-M sequences on a leading axis and move data to `device`.
    seqs = device_put(
        (
            jnp.stack(priors, 0),
            jnp.stack(params, 0),
            o_data[cutoff_down:cutoff_up],
            c0
        ), device)
    y = device_put(r_data[cutoff_down:cutoff_up], device)
    mask = device_put(m_data[cutoff_down:cutoff_up].astype(bool), device)
    # Presumably the trailing axis indexes participants — one rng key and one
    # vmapped inference run per participant; TODO confirm data layout.
    n = mask.shape[-1]
    rng_keys = random.split(rng_key, n)
    samples, potential_energy = jit(
        vmap(inference, in_axes=((2, 2, 1, 0), 1, 1, 0)))(
        seqs, y, mask, rng_keys)
    print('potential_energy', potential_energy)
    # NOTE(review): jnp.savez availability depends on the jax version;
    # np.savez also accepts jax arrays — confirm against the pinned jax.
    jnp.savez('fit_data/fit_sample_mixture_gamma{}_pref{}_{}.npz'.format(
        int(dynamic_gamma), int(dynamic_preference), mc_type),
        samples=samples)
def savez(file, *args, **kwds):
    """Drop-in wrapper around ``jnp.savez`` that first strips jax array
    wrappers from every positional and keyword argument."""
    plain_args = []
    for value in args:
        plain_args.append(_remove_jaxarray(value))
    plain_kwds = {}
    for key, value in kwds.items():
        plain_kwds[key] = _remove_jaxarray(value)
    jnp.savez(file, *plain_args, **plain_kwds)
def _save_results(
    x: jnp.ndarray,
    betas: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Save sampling results for the DLM and plot the fit per coefficient.

    Writes prior/posterior/predictive ``.npz`` archives and a stacked
    ``prediction.png`` (observed series + one panel per regression
    coefficient) under ``./data/dlm``.

    Args:
        x: observed series; column 0 is plotted.
        betas: true time-varying coefficients, shape (..., 1, beta_dim).
        prior_samples: prior samples keyed by site name.
        posterior_samples: posterior samples keyed by site name.
        posterior_predictive: predictive samples; must contain keys ``"x"``
            and ``"weight"`` with draws along axis 0 and time along axis 1.
        num_train: number of leading time steps that were used for training.
    """
    root = pathlib.Path("./data/dlm")
    root.mkdir(exist_ok=True)
    # Bug fix: filename was misspelled "piror_samples.npz"; other scripts in
    # this file save the same artifact as "prior_samples.npz".
    jnp.savez(root / "prior_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # Split predictive draws into the training window and the forecast tail.
    x_pred = posterior_predictive["x"]
    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)
    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)
    t_axis = np.arange(num_train + num_test)
    w_pred = posterior_predictive["weight"]
    w_hpdi = diagnostics.hpdi(w_pred)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]
    beta_dim = betas.shape[-1]
    plt.figure(figsize=(8, 12))
    # Panel 0: observed series; panels 1..beta_dim: one per coefficient.
    for i in range(beta_dim + 1):
        plt.subplot(beta_dim + 1, 1, i + 1)
        if i == 0:
            plt.plot(x[:, 0], label="ground truth", color=colors[0])
            plt.plot(t_train, x_pred_trn.mean(0)[:, 0],
                     label="prediction", color=colors[1])
            plt.fill_between(t_train, x_hpdi_trn[0, :, 0, 0],
                             x_hpdi_trn[1, :, 0, 0],
                             alpha=0.3, color=colors[1])
            plt.plot(t_test, x_pred_tst.mean(0)[:, 0],
                     label="forecast", color=colors[2])
            plt.fill_between(t_test, x_hpdi_tst[0, :, 0, 0],
                             x_hpdi_tst[1, :, 0, 0],
                             alpha=0.3, color=colors[2])
            plt.title("ground truth", fontsize=16)
        else:
            plt.plot(betas[:, 0, i - 1], label="ground truth",
                     color=colors[0])
            plt.plot(t_axis, w_pred.mean(0)[:, i - 1],
                     label="prediction", color=colors[1])
            plt.fill_between(t_axis, w_hpdi[0, :, i - 1, 0],
                             w_hpdi[1, :, i - 1, 0],
                             alpha=0.3, color=colors[1])
            plt.title(f"coef_{i - 1}", fontsize=16)
        plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()
def test_and_plot_collocation_svgp_quad():
    """Visual test of the SVGP collocation solver on the quadcopter task:
    plot the initial trajectory guesses over several quantities, solve for
    the geodesic trajectory, then plot the optimized trajectory over the
    same quantities and save it as an ``.npz`` archive."""
    gp = FakeSVGPQuad()
    metric = FakeSVGPMetricQuad()
    solver = FakeCollocationSVGPQuad()
    # Only LaTeX rendering is enabled; the (previously tried) font-size
    # overrides were removed from this rcParams update.
    params = {
        "text.usetex": True,
    }
    plt.rcParams.update(params)
    dir_name = create_save_dir()

    # --- figures for the initial trajectory guesses -----------------------
    fig, ax = plot_mixing_prob_and_start_end(gp, solver, solver.state_guesses)
    save_name = dir_name + "mixing_prob_2d_init_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    fig, axs = plot_svgp_and_start_end(gp, solver)
    save_name = dir_name + "svgp_2d_init_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    fig, axs = plot_epistemic_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=None
    )
    save_name = dir_name + "epistemic_var_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    plot_aleatoric_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=None
    )
    save_name = dir_name + "aleatoric_var_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    fig, ax = plot_svgp_metric_trace_and_start_end(metric, solver)
    plot_traj(fig, ax, solver.state_guesses)
    save_name = dir_name + "metric_trace_2d_init_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # Jacobian / metric figures are produced but not saved to disk here.
    fig, axs = plot_svgp_jacobian_mean(
        gp, solver, traj_opt=solver.state_guesses
    )
    fig, axs = plot_svgp_jacobian_var(
        gp, solver, traj_opt=solver.state_guesses
    )
    fig, axs = plot_svgp_metric_and_start_end(
        metric, solver, traj_opt=solver.state_guesses
    )
    # (A large block of alternative plotting calls — prob/metric-trace vs
    # time and the 3-D mean/var, metric-trace and mixing-prob trajectory
    # figures for the initial guesses — was commented out here.)
    plt.show()

    # --- solve for the geodesic trajectory --------------------------------
    geodesic_traj = test_collocation_svgp_quad()
    # (Commented-out 3-D and 2-D overlay plots of the optimized trajectory
    # were removed here; the active equivalents follow below.)

    # plot init and opt trajectories over svgp
    fig, axs = plot_svgp_and_start_end(gp, solver, traj_opt=geodesic_traj)
    save_name = dir_name + "svgp_2d_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # init and opt trajectories over mixing probability
    fig, ax = plot_mixing_prob_and_start_end(
        gp, solver, traj_opt=geodesic_traj
    )
    save_name = dir_name + "mixing_prob_2d_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # init and opt trajectories over metric trace
    fig, ax = plot_svgp_metric_trace_and_start_end(
        metric, solver, traj_opt=geodesic_traj
    )
    save_name = dir_name + "metric_trace_2d_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # prob vs time
    plot_mixing_prob_over_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    save_name = dir_name + "mixing_prob_vs_time.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # plot metric trace vs time
    plot_metirc_trace_over_time(
        metric, solver, traj_init=solver.state_guesses,
        traj_opt=geodesic_traj
    )
    save_name = dir_name + "metric_trace_vs_time.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # plot epistemic uncertainty vs time
    fig, axs = plot_epistemic_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    save_name = dir_name + "epistemic_var_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    plot_aleatoric_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    save_name = dir_name + "aleatoric_var_traj.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight",
                pad_inches=0)
    # Persist the optimized trajectory (stored under the default key arr_0).
    save_name = dir_name + "geodesic_traj.npz"
    np.savez(save_name, geodesic_traj)
    # (Commented-out 3-D mixing-prob / metric-trace overlays of the initial
    # and optimized trajectories were removed here.)
    plt.show()
params.append(seq_sim[1].copy()) c0 = np.sum( nn.one_hot(outcomes_sim[cutoff:], 4).copy().astype(float), 0) # fit simulated data seqs = device_put((np.stack(priors, 0), np.stack( params, 0), outcomes_sim[cutoff:].copy(), c0), device) y = device_put(responses_sim[cutoff:].copy(), device) mask = jnp.ones_like(y).astype(bool) rng_keys = random.split(rng_key, N) samples = jit(vmap(inference, in_axes=((2, 2, 1, 0), 1, 1, 0)))(seqs, y, mask, rng_keys) post_smpl[m_true] = samples jnp.savez('fit_sims/tmp_sims_m{}.npz'.format(m_true), samples=samples) del seq_sim, samples, agent # save posterior estimates jnp.savez('fit_sims/sims_mcomp_numin_P-{}-{}-{}-{}.npz'.format(*P_o), samples=post_smpl) # delete tmp files for m_true in m_rng: os.remove('fit_sims/tmp_sims_m{}.npz'.format(m_true))