コード例 #1
0
ファイル: base_agent.py プロジェクト: winston-ds/jax-rl
 def save(self, fp):
     """Write every registered parameter to ``fp`` as an ``.npz`` archive.

     Args:
         fp: File path or writable file object accepted by ``jnp.savez``.

     Raises:
         ValueError: If ``self._flatten_params()`` returns an empty mapping.
     """
     flat_params = self._flatten_params()
     if not flat_params:
         raise ValueError(
             'Trying to save but no parameters were registered')
     # One archive member per flattened parameter name.
     jnp.savez(fp, **flat_params)
コード例 #2
0
ファイル: util.py プロジェクト: aleccrowell/covid
def save_samples(filename, prior_samples, mcmc_samples, post_pred_samples,
                 forecast_samples):
    """Bundle the four sample sets into a single ``.npz`` archive.

    Each argument is stored under the key that matches its parameter name.

    Args:
        filename: Destination path or file object for ``np.savez``.
        prior_samples: Stored as ``prior_samples``.
        mcmc_samples: Stored as ``mcmc_samples``.
        post_pred_samples: Stored as ``post_pred_samples``.
        forecast_samples: Stored as ``forecast_samples``.
    """
    archive = {
        'prior_samples': prior_samples,
        'mcmc_samples': mcmc_samples,
        'post_pred_samples': post_pred_samples,
        'forecast_samples': forecast_samples,
    }
    np.savez(filename, **archive)
コード例 #3
0
def _save_results(
    y: np.ndarray,
    mcmc: infer.MCMC,
    prior: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    *,
    var_names: Optional[List[str]] = None,
) -> None:
    """Persist samples and diagnostic plots under ``./data/boston_pca_reg``.

    Writes ``posterior_samples.npz`` and ``posterior_predictive.npz``, then
    ``trace.png`` and ``ppc.png`` (via ArviZ) and ``prediction.png``.

    Args:
        y: Observed targets, plotted against the predictive mean.
        mcmc: Fitted sampler, converted with ``az.from_numpyro``.
        prior: Prior samples forwarded to ArviZ.
        posterior_samples: Saved to ``posterior_samples.npz``.
        posterior_predictive: Saved to ``posterior_predictive.npz``; must
            contain the key ``"y"``.
        var_names: Optional subset of variables shown in the trace plot.
    """

    root = pathlib.Path("./data/boston_pca_reg")
    # parents=True so a fresh checkout without ./data does not raise.
    root.mkdir(parents=True, exist_ok=True)

    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    # Arviz
    numpyro_data = az.from_numpyro(
        mcmc,
        prior=prior,
        posterior_predictive=posterior_predictive,
    )

    az.plot_trace(numpyro_data, var_names=var_names)
    plt.savefig(root / "trace.png")
    plt.close()

    az.plot_ppc(numpyro_data)
    plt.legend(loc="upper right")
    plt.savefig(root / "ppc.png")
    plt.close()

    # Prediction: mean over axis 0 and HPDI band, drawn against observed y.
    y_pred = posterior_predictive["y"]
    y_hpdi = diagnostics.hpdi(y_pred)
    # Dashed marker at 80% of the series; presumably the caller's
    # train/test boundary -- confirm against the training code.
    train_len = int(len(y) * 0.8)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(y, color=colors[0])
    plt.plot(y_pred.mean(axis=0), color=colors[1])
    plt.fill_between(np.arange(len(y)),
                     y_hpdi[0],
                     y_hpdi[1],
                     color=colors[1],
                     alpha=0.3)
    plt.axvline(train_len, linestyle="--", color=colors[2])
    plt.xlabel("Index [a.u.]")
    plt.ylabel("Target [a.u.]")
    plt.savefig(root / "prediction.png")
    plt.close()
コード例 #4
0
    def get_estimate(self,
                     data,
                     kwargs,
                     fit_kwargs=None,
                     state=False,
                     validate=True,
                     fit=False,
                     set=True):
        """Check ``imnn.get_estimate`` against the stored regression target.

        Args:
            data: Input passed to ``imnn.get_estimate``.
            kwargs: Constructor keyword arguments; deep-copied and passed
                through ``self.preload`` before building the IMNN.
            fit_kwargs: Keyword arguments for ``imnn.fit``, containing
                ``"λ"`` and ``"ϵ"``. Only read when ``fit`` is true, so it
                may be left as ``None`` otherwise.
            state: Forwarded to ``self.preload`` and ``self.set_name``.
            validate: Forwarded to ``self.preload`` and ``self.set_name``.
            fit: Train via ``imnn.fit`` instead of only calling
                ``imnn.set_F_statistics``.
            set: If false, assert that ``get_estimate`` raises because the
                Fisher information has not been computed, then return.
        """
        _kwargs = copy.deepcopy(kwargs)
        _kwargs = self.preload(_kwargs, state=state, validate=validate)

        imnn = self.imnn(**_kwargs)

        if not set:
            with pytest.raises(ValueError) as info:
                imnn.get_estimate(data)
            assert info.match(
                re.escape(
                    "Fisher information has not yet been calculated. Please " +
                    "run `imnn.set_F_statistics({w}, {key}, {validate}) " +
                    "with `w = imnn.final_w`, `w = imnn.best_w`, " +
                    "`w = imnn.inital_w` or otherwise, `validate = True` " +
                    "should be set if not simulating on the fly."))
            return

        if fit:
            # fit_kwargs defaults to None, so it is only unpacked when fitting.
            _fit_kwargs = copy.deepcopy(fit_kwargs)
            λ = _fit_kwargs.pop("λ")
            ϵ = _fit_kwargs.pop("ϵ")
            imnn.fit(λ, ϵ, **_fit_kwargs)
        else:
            imnn.set_F_statistics(key=self.stats_key)
        estimate = imnn.get_estimate(data)

        name = self.set_name(state, validate, fit, _kwargs["n_d"])
        path = f"test/{self.filename}.npz"
        key = f"{name}estimate"
        if self.save:
            files = {key: estimate}
            try:
                # Read everything and close the archive before rewriting the
                # same path. allow_pickle: the shared file may hold object
                # arrays written by other tests.
                with np.load(path, allow_pickle=True) as stored:
                    existing = {k: stored[k] for k in stored.files}
            except FileNotFoundError:
                existing = {}
            # Existing targets take precedence; only new keys are added.
            np.savez(path, **{**files, **existing})
        with np.load(path) as targets:
            assert np.all(np.equal(estimate, targets[key]))
コード例 #5
0
def save_test_embeddings(key, experiment, save_path, n_samples_per_batch=64):
    """Encode the test+validation split, embed it with UMAP and save it.

    Args:
        key: Random key forwarded to ``batched_evaluate``.
        experiment: 4-tuple ``(exp, sampler, encoder, decoder)``; only
            ``exp`` and ``encoder`` are used here.
        save_path: Destination for ``np.savez``; arrays are stored as
            ``z`` (encoder outputs), ``y`` (labels) and ``u`` (UMAP output).
        n_samples_per_batch: Batch size forwarded to ``batched_evaluate``.
    """
    exp, _sampler, encoder, _decoder = experiment
    _n_train, n_test, n_validation = exp.split_shapes
    loader = exp.data_loader

    # Fetch every test and validation example, with non-one-hot labels.
    inputs, labels = loader((n_test + n_validation, ),
                            start=0,
                            split='tpv',
                            return_labels=True,
                            onehot=False)
    codes, _ = batched_evaluate(key, encoder, inputs, n_samples_per_batch)

    # Supervised UMAP (labels passed as y) with a fixed seed.
    umap_embedding = umap.UMAP(random_state=0).fit_transform(codes, y=labels)

    np.savez(save_path,
             z=np.array(codes),
             y=np.array(labels),
             u=umap_embedding)
コード例 #6
0
ファイル: run_fits.py プロジェクト: dimarkov/pybefit
def main(m, seed):
    """Fit ``generative_model`` with NUTS for one belief setting and save WAIC.

    Args:
        m: Belief-model index. Values up to 10 are passed as ``nu_max``;
            larger values use ``nu_max=10`` and ``nu_min=m - 10``.
        seed: Seed for ``random.PRNGKey``.

    Reads ``outcomes_data``, ``responses_data`` and ``mask_data`` from the
    enclosing (module) scope. Writes
    ``fit_waic_sample/dyn_fit_waic_sample_minf{m}.npz`` and prints the mean
    potential energy of the chain.
    """
    rng_key = random.PRNGKey(seed)
    lo, hi = 100, 800  # index window [cutoff_down, cutoff_up)

    if m <= 10:
        nu_kwargs = {'nu_max': m}
    else:
        nu_kwargs = {'nu_max': 10, 'nu_min': m - 10}
    seq, _ = estimate_beliefs(outcomes_data, responses_data, mask=mask_data, **nu_kwargs)

    model = generative_model
    mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)

    beliefs = seq['beliefs']
    seq = (beliefs[0][lo:hi], beliefs[1][lo:hi])
    obs = responses_data[lo:hi]
    obs_mask = mask_data[lo:hi].astype(bool)

    _, run_key = random.split(rng_key)
    mcmc.run(run_key, seq, y=obs, mask=obs_mask, extra_fields=('potential_energy',))

    samples = mcmc.get_samples()
    waic = log_pred_density(model, samples, seq, y=obs, mask=obs_mask)['waic']

    out_path = 'fit_waic_sample/dyn_fit_waic_sample_minf{}.npz'.format(m)
    jnp.savez(out_path, samples=samples, waic=waic)

    print(mcmc.get_extra_fields()['potential_energy'].mean())
コード例 #7
0
def main():
    """Compute, save and plot Hessian eigenvalue spectra of downloaded circuits.

    Filters come from the CLI (``n_qubits`` plus optional ``n_layers_list``
    and ``lr``). For each matching circuit, ``jax.hessian`` of the loss is
    built once per (n_layers, rot_axis) combination and reused. Parameters,
    the Hamiltonian and the eigenvalues go to
    ``results_hessian/<name>_all.npz``; spectrum and histogram PDFs are
    written alongside.
    """
    args = get_arguments()
    n_qubits = args.n_qubits
    g, h = 2, 0

    filters = {'n_qubits': n_qubits, 'g': g, 'h': h}
    if args.n_layers_list is not None:
        print(f'Add [n_layers] filter: {args.n_layers_list}')
        filters['n_layers'] = {'$in': args.n_layers_list}
    if args.lr is not None:
        print(f'Add [lr] filter: {args.lr}')
        filters['lr'] = args.lr

    resdir = Path('results_hessian')
    opt_circuits = download_circuits(args.project, filters, resdir)

    ham_matrix = qnnops.ising_hamiltonian(n_qubits, g, h)

    print('Hessian spectrum')
    hessian_cache = {}
    for cfg, fpath in opt_circuits:
        print(f'| computing hessian of {fpath}')
        params = jnp.load(fpath)
        n_layers, rot_axis = cfg['n_layers'], cfg['rot_axis']
        circuit_name = f'Q{n_qubits}-L{n_layers}-R{rot_axis}'
        if circuit_name not in hessian_cache:
            loss_fn = get_loss_fn(ham_matrix, n_qubits, n_layers, rot_axis)
            hessian_cache[circuit_name] = jax.hessian(loss_fn)
        _, eigvals = compute_hessian_eigenvalues(hessian_cache[circuit_name], params)

        name = get_normalized_name(cfg)
        jnp.savez(
            resdir / f'{name}_all.npz',
            params=params,
            ham_matrix=ham_matrix,
            hess_spectrum=eigvals,
        )
        print('| ...plotting hessian spectrum and histogram')
        plot_spectrum(eigvals, resdir / f'{name}_spectrum.pdf')
        plot_histogram(eigvals, resdir / f'{name}_histogram.pdf')
コード例 #8
0
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Save samples and a prediction/forecast plot under ``./data/seasonal``.

    Args:
        x: Observed series; plotted (flattened) as ground truth.
        prior_samples: Saved to ``piror_samples.npz``.
        posterior_samples: Saved to ``posterior_samples.npz``.
        posterior_predictive: Saved to ``posterior_predictive.npz``; its
            ``"x"`` entry is split at ``num_train`` along axis 1 into the
            in-sample prediction and the forecast.
        num_train: Number of leading time steps plotted as "prediction".
    """

    root = pathlib.Path("./data/seasonal")
    # parents=True so a fresh checkout without ./data does not raise.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE(review): "piror" looks like a typo for "prior"; the filename is
    # kept as-is because other scripts may read it under this name.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]

    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])

    plt.plot(t_train,
             x_pred_trn.mean(0)[:, 0],
             label="prediction",
             color=colors[1])
    plt.fill_between(t_train,
                     x_hpdi_trn[0, :, 0, 0],
                     x_hpdi_trn[1, :, 0, 0],
                     alpha=0.3,
                     color=colors[1])

    plt.plot(t_test,
             x_pred_tst.mean(0)[:, 0],
             label="forecast",
             color=colors[2])
    plt.fill_between(t_test,
                     x_hpdi_tst[0, :, 0, 0],
                     x_hpdi_tst[1, :, 0, 0],
                     alpha=0.3,
                     color=colors[2])

    plt.ylim(x.min() - 0.5, x.max() + 0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()
コード例 #9
0
def main() -> None:
    """Fit the ``sgt`` model with NUTS, forecast, and save/plot the results.

    The first ``test_index`` observations are used for inference and the
    remaining ``len(df) - test_index`` steps are forecast. Samples are
    written to ``./data/time_series`` and plotted via ``plot_results``.
    """

    df = load_dataset()
    test_index = 80
    test_len = len(df) - test_index
    # Positional slice: ``df.loc[:test_index]`` is label-inclusive and would
    # put test_index + 1 points in training, so training plus the test_len
    # forecast steps would overrun the series by one.
    y_train = jnp.array(df["value"].iloc[:test_index].to_numpy(),
                        dtype=jnp.float32)

    # Inference
    kernel = NUTS(sgt)
    mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
    mcmc.run(random.PRNGKey(0), y_train, seasonality=38)
    mcmc.print_summary()
    posterior_samples = mcmc.get_samples()

    # Prediction
    predictive = Predictive(sgt,
                            posterior_samples,
                            return_sites=["y_forecast"])
    posterior_predictive = predictive(random.PRNGKey(1),
                                      y_train,
                                      seasonality=38,
                                      future=test_len)

    root = pathlib.Path("./data/time_series")
    # parents=True so a fresh checkout without ./data does not raise.
    root.mkdir(parents=True, exist_ok=True)

    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)
    plot_results(
        df["time"].values,
        df["value"].values,
        posterior_samples,
        posterior_predictive,
        test_index,
        root,
    )
コード例 #10
0
def _save_results(
    x: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
) -> None:
    """Save samples and a prediction/forecast plot under ``./data/kalman``.

    Args:
        x: Observed training series; its length (``x.shape[0]``) is the
            split point between "prediction" and "forecast".
        prior_samples: Saved to ``piror_samples.npz``.
        posterior_samples: Saved to ``posterior_samples.npz``.
        posterior_predictive: Saved to ``posterior_predictive.npz``; its
            ``"x"`` entry is split along axis 1 at ``len(x)``.
    """

    root = pathlib.Path("./data/kalman")
    # parents=True so a fresh checkout without ./data does not raise.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE(review): "piror" looks like a typo for "prior"; the filename is
    # kept as-is because other scripts may read it under this name.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    len_train = x.shape[0]

    x_pred_trn = posterior_predictive["x"][:, :len_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    x_pred_tst = posterior_predictive["x"][:, len_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)

    len_test = x_pred_tst.shape[1]

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    plt.figure(figsize=(12, 6))
    plt.plot(x.ravel(), label="ground truth", color=colors[0])

    t_train = np.arange(len_train)
    plt.plot(t_train,
             x_pred_trn.mean(0).ravel(),
             label="prediction",
             color=colors[1])
    plt.fill_between(t_train,
                     x_hpdi_trn[0].ravel(),
                     x_hpdi_trn[1].ravel(),
                     alpha=0.3,
                     color=colors[1])

    t_test = np.arange(len_train, len_train + len_test)
    plt.plot(t_test,
             x_pred_tst.mean(0).ravel(),
             label="forecast",
             color=colors[2])
    plt.fill_between(t_test,
                     x_hpdi_tst[0].ravel(),
                     x_hpdi_tst[1].ravel(),
                     alpha=0.3,
                     color=colors[2])

    plt.legend()
    plt.tight_layout()
    plt.savefig(root / "kalman.png")
    plt.close()
コード例 #11
0
def main(args: argparse.Namespace) -> None:
    """Fit the model selected by ``args.model`` on JSB chorales and save draws.

    Args:
        args: Parsed CLI options; reads ``model``, ``num_warmup``,
            ``num_samples`` and ``num_chains``.

    Writes ``prior_samples.npz``, ``posterior_samples.npz`` and
    ``predictive_samples.npz`` to ``./data/hmm_enum``.
    """

    model = model_dict[args.model]

    _, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
    lengths, sequences = fetch()

    # Remove never used data dimension to reduce computation time
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    sequences = sequences[..., present_notes]
    batch, seq_len, data_dim = sequences.shape

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_prior, rng_key_pred = random.split(rng_key, 3)

    predictive = infer.Predictive(model, num_samples=10)
    prior_samples = predictive(rng_key_prior,
                               batch=batch,
                               seq_len=seq_len,
                               data_dim=data_dim,
                               future_steps=20)

    kernel = infer.NUTS(model)
    # Keyword arguments: current NumPyro releases make these keyword-only.
    mcmc = infer.MCMC(kernel,
                      num_warmup=args.num_warmup,
                      num_samples=args.num_samples,
                      num_chains=args.num_chains)
    mcmc.run(rng_key, sequences, lengths)
    posterior_samples = mcmc.get_samples()

    predictive = infer.Predictive(model, posterior_samples)
    predictive_samples = predictive(rng_key_pred,
                                    sequences,
                                    lengths,
                                    future_steps=10)

    path = pathlib.Path("./data/hmm_enum")
    # parents=True so a fresh checkout without ./data does not raise.
    path.mkdir(parents=True, exist_ok=True)

    jnp.savez(path / "prior_samples.npz", **prior_samples)
    jnp.savez(path / "posterior_samples.npz", **posterior_samples)
    jnp.savez(path / "predictive_samples.npz", **predictive_samples)
コード例 #12
0
def save_params(path: Union[str, pathlib.Path],
                params: Dict[str, jnp.ndarray]) -> None:
    """Store each entry of ``params`` as a member of an ``.npz`` archive.

    Args:
        path: Destination file (or file object) handed to ``jnp.savez``.
        params: Mapping of member name to array.
    """
    target = path
    jnp.savez(target, **params)
コード例 #13
0
ファイル: run_fits_single.py プロジェクト: dimarkov/pybefit
def main(m, seed, device, dynamic_gamma, dynamic_preference):
    """Fit ``single_model`` with vectorised NUTS per mask column and save WAIC.

    Args:
        m: Belief-model index; ``m <= 10`` is passed as ``nu_max``, larger
            values as ``nu_max=10, nu_min=m - 10``.
        seed: Seed for the top-level ``random.PRNGKey``.
        device: Forwarded to ``estimate_beliefs``.
        dynamic_gamma: Forwarded to ``single_model``; also selects the
            warmup/sample/chain budget.
        dynamic_preference: Forwarded to ``single_model``.

    Inference is vmapped over the last axis of the mask (presumably one
    entry per subject -- confirm with ``load_data``). Writes
    ``fit_data/fit_waic_sample_minf{m}_gamma{g}_pref{p}_long.npz``.
    """
    import jax.numpy as jnp
    from jax import random, nn, vmap, jit, devices, device_put
    from numpyro.infer import MCMC, NUTS

    # project helpers: belief estimation, the model, WAIC computation, data
    from utils import estimate_beliefs, single_model, log_pred_density
    from stats import load_data

    outcomes_data, responses_data, mask_data, _, _, _ = load_data()

    mask_data = jnp.array(mask_data)
    responses_data = jnp.array(responses_data).astype(jnp.int32)
    outcomes_data = jnp.array(outcomes_data).astype(jnp.int32)

    rng_key = random.PRNGKey(seed)
    lo, hi = 400, 1000  # index window [cutoff_down, cutoff_up)

    print(m, seed, dynamic_gamma, dynamic_preference)

    if m <= 10:
        nu_kwargs = {'nu_max': m}
    else:
        nu_kwargs = {'nu_max': 10, 'nu_min': m - 10}
    seq, _ = estimate_beliefs(outcomes_data, responses_data, device, mask=mask_data, **nu_kwargs)

    def model(*args):
        return single_model(*args, dynamic_gamma=dynamic_gamma, dynamic_preference=dynamic_preference)

    # initial preferences: masked one-hot counts of outcomes before the window
    c0 = jnp.sum(nn.one_hot(outcomes_data[:lo], 4) * jnp.expand_dims(mask_data[:lo], -1), 0)

    num_warmup, num_samples, num_chains = (500, 500, 2) if dynamic_gamma else (100, 100, 10)

    def inference(belief_sequences, obs, mask, rng_key):
        mcmc = MCMC(
            NUTS(model, dense_mass=True),
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False,
        )
        mcmc.run(rng_key, belief_sequences, obs, mask, extra_fields=('potential_energy',))
        return mcmc.get_samples(), mcmc.get_extra_fields()['potential_energy'].mean()

    beliefs = seq['beliefs']
    seq = (beliefs[0][lo:hi], beliefs[1][lo:hi], outcomes_data[lo:hi], c0)
    y = responses_data[lo:hi]
    mask = mask_data[lo:hi].astype(bool)

    rng_keys = random.split(rng_key, mask.shape[-1])
    # Map over axis 1 of the windowed arrays; c0 and the keys over axis 0.
    seq_axes = (1, 1, 1, 0)
    samples, potential_energy = jit(vmap(inference, in_axes=(seq_axes, 1, 1, 0)))(seq, y, mask, rng_keys)

    batched_log_pred_density = vmap(lambda *args: log_pred_density(model, *args), in_axes=(0, seq_axes, 1, 1))
    waic, log_likelihood = batched_log_pred_density(samples, seq, y, mask)

    print('waic', waic.mean())

    out_path = 'fit_data/fit_waic_sample_minf{}_gamma{}_pref{}_long.npz'.format(
        m, int(dynamic_gamma), int(dynamic_preference))
    jnp.savez(out_path, samples=samples, waic=waic, log_likelihood=log_likelihood)
コード例 #14
0
def save_history(filename, history, upload_to_wandb=True):
    """Save ``history`` as an ``.npz`` archive and optionally upload it.

    Args:
        filename: Name resolved to a full path by ``get_result_path``.
        history: Mapping of member name to array-like.
        upload_to_wandb: When true, hand the written path to
            ``safe_wandb_save``.
    """
    target = str(get_result_path(filename))
    jnp.savez(target, **history)
    if not upload_to_wandb:
        return
    safe_wandb_save(target)
コード例 #15
0
    def combined_running_test(self,
                              data,
                              kwargs,
                              fit_kwargs,
                              state=False,
                              validate=True,
                              fit=False,
                              none_first=True,
                              implemented=True,
                              aggregated=False):
        """Run the IMNN workflow end to end and compare with stored targets.

        Asserts that ``get_estimate`` fails before statistics exist, then
        either fits (``fit``) or only calls ``set_F_statistics``, compares
        estimates and Fisher statistics with ``test/<filename>.npz``, writes
        a diagnostic plot, and -- when ``fit and implemented and not
        aggregated`` -- checks that re-fitting with the opposite
        ``print_rate`` mode raises.

        Args:
            data: Iterable of inputs, each passed to ``imnn.get_estimate``.
            kwargs: Constructor kwargs; deep-copied, then ``self.preload``-ed.
            fit_kwargs: ``imnn.fit`` kwargs including ``"λ"`` and ``"ϵ"``;
                only read when ``fit`` is true.
            state: Forwarded to ``self.preload`` and ``self.set_name``.
            validate: Forwarded to ``self.preload`` and ``self.set_name``.
            fit: Fit the network instead of only setting the statistics.
            none_first: First fit with ``print_rate=None`` (expecting the
                later progress-bar fit to fail), or the reverse.
            implemented: If false, expect "`get_summaries` not implemented".
            aggregated: If true, skip the final mode-switch check.
        """
        _kwargs = copy.deepcopy(kwargs)
        _kwargs = self.preload(_kwargs, state=state, validate=validate)

        imnn = self.imnn(**_kwargs)

        with pytest.raises(ValueError) as info:
            imnn.get_estimate(data)
        assert info.match(
            re.escape(
                "Fisher information has not yet been calculated. Please " +
                "run `imnn.set_F_statistics({w}, {key}, {validate}) " +
                "with `w = imnn.final_w`, `w = imnn.best_w`, " +
                "`w = imnn.inital_w` or otherwise, `validate = True` " +
                "should be set if not simulating on the fly."))

        time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        name = self.set_name(state, validate, fit, _kwargs["n_d"])

        if fit:
            _fit_kwargs = copy.deepcopy(fit_kwargs)
            λ = _fit_kwargs.pop("λ")
            ϵ = _fit_kwargs.pop("ϵ")

            if not implemented:
                with pytest.raises(ValueError) as info:
                    imnn.fit(λ, ϵ, **_fit_kwargs)
                assert info.match("`get_summaries` not implemented")
            else:
                if none_first:
                    _fit_kwargs["print_rate"] = None
                    string = ("Cannot run IMNN with progress bar after " +
                              "running without progress bar. Either set " +
                              "`print_rate` to None or reinitialise the IMNN.")
                else:
                    _fit_kwargs["print_rate"] = 100
                    string = ("Cannot run IMNN without progress bar after " +
                              "running with progress bar. Either set " +
                              "`print_rate` to an int or reinitialise the " +
                              "IMNN.")
                imnn.fit(λ, ϵ, **_fit_kwargs)
                # Flip the mode so the final re-fit is expected to fail.
                if none_first:
                    _fit_kwargs["print_rate"] = 100
                else:
                    _fit_kwargs["print_rate"] = None
        else:
            if not implemented:
                with pytest.raises(ValueError) as info:
                    imnn.set_F_statistics(key=self.stats_key)
                assert info.match("`get_summaries` not implemented")
            else:
                imnn.set_F_statistics(key=self.stats_key)

        if implemented:
            estimates = [imnn.get_estimate(element) for element in data]

            path = f"test/{self.filename}.npz"
            if self.save:
                files = {
                    f"{name}estimates": estimates,
                    f"{name}F": imnn.F,
                    f"{name}C": imnn.C,
                    f"{name}invC": imnn.invC,
                    f"{name}dμ_dθ": imnn.dμ_dθ,
                    f"{name}μ": imnn.μ,
                    f"{name}invF": imnn.invF
                }
                try:
                    # Read everything and close the archive before rewriting
                    # the same path; allow_pickle because object arrays may
                    # be stored (the check below loads with it too).
                    with np.load(path, allow_pickle=True) as stored:
                        existing = {k: stored[k] for k in stored.files}
                except FileNotFoundError:
                    existing = {}
                # Existing targets take precedence; only new keys are added.
                # np.savez gets no ``allow_pickle`` keyword: on NumPy versions
                # without that parameter it would be saved as an array member.
                np.savez(path, **{**files, **existing})
            with np.load(path, allow_pickle=True) as targets:
                stored_estimates = targets[f"{name}estimates"]
                for i, estimate in enumerate(estimates):
                    assert np.all(np.equal(estimate, stored_estimates[i]))
                assert np.all(np.equal(imnn.F, targets[f"{name}F"]))
                assert np.all(np.equal(imnn.C, targets[f"{name}C"]))
                assert np.all(np.equal(imnn.invC, targets[f"{name}invC"]))
                assert np.all(np.equal(imnn.dμ_dθ, targets[f"{name}dμ_dθ"]))
                assert np.all(np.equal(imnn.μ, targets[f"{name}μ"]))
                assert np.all(np.equal(imnn.invF, targets[f"{name}invF"]))

        imnn.plot(expected_detF=50,
                  filename=f"test/figures/{self.filename}/{name}_{time}.pdf")
        plt.close("all")
        assert pathlib.Path(
            f"test/figures/{self.filename}/{name}_{time}.pdf").is_file()

        if fit and implemented and (not aggregated):
            with pytest.raises(ValueError) as info:
                imnn.fit(λ, ϵ, **_fit_kwargs)
            assert info.match(string)
コード例 #16
0
    def fit(self,
            kwargs,
            fit_kwargs,
            state=False,
            validate=True,
            set=True,
            none_first=True):
        """Fit an IMNN, compare its statistics with stored targets, then
        check that re-fitting with the opposite ``print_rate`` mode fails.

        Args:
            kwargs: Constructor kwargs; deep-copied, then ``self.preload``-ed.
            fit_kwargs: ``imnn.fit`` kwargs including ``"λ"`` and ``"ϵ"``.
            state: Forwarded to ``self.preload`` and ``self.set_name``.
            validate: Forwarded to ``self.preload`` and ``self.set_name``.
            set: If false, expect "`get_summaries` not implemented" from
                ``fit`` and return.
            none_first: First fit with ``print_rate=None`` (then expect the
                progress-bar re-fit to fail), or the reverse.
        """
        _kwargs = copy.deepcopy(kwargs)
        _kwargs = self.preload(_kwargs, state=state, validate=validate)
        _fit_kwargs = copy.deepcopy(fit_kwargs)
        λ = _fit_kwargs.pop("λ")
        ϵ = _fit_kwargs.pop("ϵ")

        imnn = self.imnn(**_kwargs)
        if not set:
            # ``imnn`` is still untouched, so no second instance is needed.
            with pytest.raises(ValueError) as info:
                imnn.fit(λ, ϵ, **_fit_kwargs)
            assert info.match("`get_summaries` not implemented")
            return

        if none_first:
            _fit_kwargs["print_rate"] = None
            string = ("Cannot run IMNN with progress bar after running " +
                      "without progress bar. Either set `print_rate` to " +
                      "None or reinitialise the IMNN.")
        else:
            _fit_kwargs["print_rate"] = 100
            string = ("Cannot run IMNN without progress bar after running " +
                      "with progress bar. Either set `print_rate` to an int " +
                      "or reinitialise the IMNN.")

        imnn.fit(λ, ϵ, **_fit_kwargs)

        name = self.set_name(state, validate, False, _kwargs["n_d"])
        path = f"test/{self.filename}.npz"
        if self.save:
            files = {
                f"{name}F": imnn.F,
                f"{name}C": imnn.C,
                f"{name}invC": imnn.invC,
                f"{name}dμ_dθ": imnn.dμ_dθ,
                f"{name}μ": imnn.μ,
                f"{name}invF": imnn.invF,
            }
            try:
                # Read everything and close the archive before rewriting the
                # same path. allow_pickle: the shared file may hold object
                # arrays written by other tests.
                with np.load(path, allow_pickle=True) as stored:
                    existing = {k: stored[k] for k in stored.files}
            except FileNotFoundError:
                existing = {}
            # Existing targets take precedence; only new keys are added.
            np.savez(path, **{**files, **existing})
        with np.load(path) as targets:
            assert np.all(np.equal(imnn.F, targets[f"{name}F"]))
            assert np.all(np.equal(imnn.C, targets[f"{name}C"]))
            assert np.all(np.equal(imnn.invC, targets[f"{name}invC"]))
            assert np.all(np.equal(imnn.dμ_dθ, targets[f"{name}dμ_dθ"]))
            assert np.all(np.equal(imnn.μ, targets[f"{name}μ"]))
            assert np.all(np.equal(imnn.invF, targets[f"{name}invF"]))

        # Flip the progress-bar mode; the same instance must now refuse to fit.
        if none_first:
            _fit_kwargs["print_rate"] = 100
        else:
            _fit_kwargs["print_rate"] = None

        with pytest.raises(ValueError) as info:
            # Re-fit the already-fitted instance (the original called
            # ``imnn(**_kwargs)``, i.e. the instance itself, not ``fit``).
            imnn.fit(λ, ϵ, **_fit_kwargs)
        assert info.match(string)
コード例 #17
0
ファイル: run_fits_mixture.py プロジェクト: dimarkov/pybefit
def main(seed, device, dynamic_gamma, dynamic_preference, mc_type):
    """Fit ``mixture_model`` over several belief models and save the samples.

    Args:
        seed: Seed for the top-level ``random.PRNGKey``.
        device: Device spec; forwarded to ``estimate_beliefs`` and then
            resolved with ``devices(device)[0]`` for ``device_put``.
        dynamic_gamma: Forwarded to ``mixture_model``; also selects the
            warmup/sample/chain budget.
        dynamic_preference: Forwarded to ``mixture_model``.
        mc_type: ``'nu_max'`` compares M = 1..10; any other value compares
            M = 1 and M = 11..19.

    Writes ``fit_data/fit_sample_mixture_gamma{g}_pref{p}_{mc_type}.npz``.
    """

    import jax.numpy as jnp
    from jax import random, nn, devices, device_put, vmap, jit
    from numpyro.infer import MCMC, NUTS

    # project helpers: belief estimation, the mixture model, data loading
    from utils import estimate_beliefs, mixture_model
    from stats import load_data

    outcomes_data, responses_data, mask_data, _, _, _ = load_data()

    print(seed, device, dynamic_gamma, dynamic_preference)

    def model(*args):
        return mixture_model(*args, dynamic_gamma=dynamic_gamma, dynamic_preference=dynamic_preference)

    m_data = jnp.array(mask_data)
    r_data = jnp.array(responses_data).astype(jnp.int32)
    o_data = jnp.array(outcomes_data).astype(jnp.int32)

    rng_key = random.PRNGKey(seed)
    lo, hi = 400, 1000  # index window [cutoff_down, cutoff_up)

    if mc_type == 'nu_max':
        model_ids = list(range(1, 11))  # model comparison for regular condition
    else:
        model_ids = [1] + list(range(11, 20))  # model comparison for irregular condition

    priors, params = [], []
    for M in model_ids:
        nu_kwargs = {'nu_max': M} if M <= 10 else {'nu_max': 10, 'nu_min': M - 10}
        seq, _ = estimate_beliefs(o_data, r_data, device, mask=m_data, **nu_kwargs)
        beliefs = seq['beliefs']
        priors.append(beliefs[0][lo:hi])
        params.append(beliefs[1][lo:hi])

    device = devices(device)[0]

    # initial preferences: masked one-hot counts of outcomes before the window
    c0 = jnp.sum(nn.one_hot(outcomes_data[:lo], 4) * jnp.expand_dims(mask_data[:lo], -1), 0)

    num_warmup, num_samples, num_chains = (1000, 1000, 1) if dynamic_gamma else (200, 200, 5)

    def inference(belief_sequences, obs, mask, rng_key):
        mcmc = MCMC(
            NUTS(model, dense_mass=True),
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False,
        )
        mcmc.run(rng_key, belief_sequences, obs, mask, extra_fields=('potential_energy',))
        return mcmc.get_samples(), mcmc.get_extra_fields()['potential_energy'].mean()

    seqs = device_put(
        (jnp.stack(priors, 0), jnp.stack(params, 0), o_data[lo:hi], c0),
        device,
    )
    y = device_put(r_data[lo:hi], device)
    mask = device_put(m_data[lo:hi].astype(bool), device)

    rng_keys = random.split(rng_key, mask.shape[-1])

    batched_inference = jit(vmap(inference, in_axes=((2, 2, 1, 0), 1, 1, 0)))
    samples, potential_energy = batched_inference(seqs, y, mask, rng_keys)

    print('potential_energy', potential_energy)

    out_path = 'fit_data/fit_sample_mixture_gamma{}_pref{}_{}.npz'.format(
        int(dynamic_gamma), int(dynamic_preference), mc_type)
    jnp.savez(out_path, samples=samples)
コード例 #18
0
def savez(file, *args, **kwds):
  """Call ``jnp.savez`` after passing every array through ``_remove_jaxarray``.

  Positional and keyword arrays are converted first (positional ones
  first), then written to ``file``.
  """
  positional = list(map(_remove_jaxarray, args))
  named = {name: _remove_jaxarray(value) for name, value in kwds.items()}
  jnp.savez(file, *positional, **named)
コード例 #19
0
def _save_results(
    x: jnp.ndarray,
    betas: jnp.ndarray,
    prior_samples: Dict[str, jnp.ndarray],
    posterior_samples: Dict[str, jnp.ndarray],
    posterior_predictive: Dict[str, jnp.ndarray],
    num_train: int,
) -> None:
    """Save samples and plot the series and each coefficient under ``./data/dlm``.

    Args:
        x: Observed series; column 0 is plotted as ground truth.
        betas: Ground-truth coefficients; ``betas[:, 0, i]`` is plotted in
            subplot ``i + 1``.
        prior_samples: Saved to ``piror_samples.npz``.
        posterior_samples: Saved to ``posterior_samples.npz``.
        posterior_predictive: Saved to ``posterior_predictive.npz``; its
            ``"x"`` entry is split at ``num_train`` along axis 1, and its
            ``"weight"`` entry is plotted against ``betas``.
        num_train: Number of leading time steps plotted as "prediction".
    """

    root = pathlib.Path("./data/dlm")
    # parents=True so a fresh checkout without ./data does not raise.
    root.mkdir(parents=True, exist_ok=True)

    # NOTE(review): "piror" looks like a typo for "prior"; the filename is
    # kept as-is because other scripts may read it under this name.
    jnp.savez(root / "piror_samples.npz", **prior_samples)
    jnp.savez(root / "posterior_samples.npz", **posterior_samples)
    jnp.savez(root / "posterior_predictive.npz", **posterior_predictive)

    x_pred = posterior_predictive["x"]

    x_pred_trn = x_pred[:, :num_train]
    x_hpdi_trn = diagnostics.hpdi(x_pred_trn)
    t_train = np.arange(num_train)

    x_pred_tst = x_pred[:, num_train:]
    x_hpdi_tst = diagnostics.hpdi(x_pred_tst)
    num_test = x_pred_tst.shape[1]
    t_test = np.arange(num_train, num_train + num_test)

    t_axis = np.arange(num_train + num_test)
    w_pred = posterior_predictive["weight"]
    w_hpdi = diagnostics.hpdi(w_pred)

    prop_cycle = plt.rcParams["axes.prop_cycle"]
    colors = prop_cycle.by_key()["color"]

    beta_dim = betas.shape[-1]
    plt.figure(figsize=(8, 12))
    # Subplot 0: the series; subplots 1..beta_dim: one per coefficient.
    for i in range(beta_dim + 1):
        plt.subplot(beta_dim + 1, 1, i + 1)
        if i == 0:
            plt.plot(x[:, 0], label="ground truth", color=colors[0])
            plt.plot(t_train,
                     x_pred_trn.mean(0)[:, 0],
                     label="prediction",
                     color=colors[1])
            plt.fill_between(t_train,
                             x_hpdi_trn[0, :, 0, 0],
                             x_hpdi_trn[1, :, 0, 0],
                             alpha=0.3,
                             color=colors[1])
            plt.plot(t_test,
                     x_pred_tst.mean(0)[:, 0],
                     label="forecast",
                     color=colors[2])
            plt.fill_between(t_test,
                             x_hpdi_tst[0, :, 0, 0],
                             x_hpdi_tst[1, :, 0, 0],
                             alpha=0.3,
                             color=colors[2])
            plt.title("ground truth", fontsize=16)
        else:
            plt.plot(betas[:, 0, i - 1], label="ground truth", color=colors[0])
            plt.plot(t_axis,
                     w_pred.mean(0)[:, i - 1],
                     label="prediction",
                     color=colors[1])
            plt.fill_between(t_axis,
                             w_hpdi[0, :, i - 1, 0],
                             w_hpdi[1, :, i - 1, 0],
                             alpha=0.3,
                             color=colors[1])
            plt.title(f"coef_{i - 1}", fontsize=16)
        plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(root / "prediction.png")
    plt.close()
コード例 #20
0
def _save_current_figure(save_name):
    """Save the current matplotlib figure to ``save_name``.

    The figure is written with a transparent background, a tight bounding box
    and zero padding.
    """
    plt.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0)


def test_and_plot_collocation_svgp_quad():
    """Plot the initial trajectory, solve for the geodesic, then plot both.

    Steps:
      1. Build the fake SVGP, metric and collocation-solver fixtures.
      2. Save PDF figures for the initial trajectory (``solver.state_guesses``)
         and display them with ``plt.show()``.
      3. Run ``test_collocation_svgp_quad()`` to obtain the optimised
         (geodesic) trajectory.
      4. Save PDF figures overlaying the initial and optimised trajectories,
         store the geodesic trajectory in ``geodesic_traj.npz`` and display
         the remaining figures.

    Side effect: enables LaTeX text rendering through the global
    ``plt.rcParams``; the setting persists after this function returns.
    """
    gp = FakeSVGPQuad()
    metric = FakeSVGPMetricQuad()
    solver = FakeCollocationSVGPQuad()
    # Global change: every later figure in this process uses LaTeX text.
    plt.rcParams.update({"text.usetex": True})

    # NOTE(review): file names are appended by plain string concatenation, so
    # this presumably returns a directory path ending in a separator — confirm.
    dir_name = create_save_dir()

    # ---- Figures for the initial trajectory guess ----
    plot_mixing_prob_and_start_end(gp, solver, solver.state_guesses)
    _save_current_figure(dir_name + "mixing_prob_2d_init_traj.pdf")

    plot_svgp_and_start_end(gp, solver)
    _save_current_figure(dir_name + "svgp_2d_init_traj.pdf")

    # Saved under *_init_traj names so that the init+opt versions written to
    # epistemic_var_traj.pdf / aleatoric_var_traj.pdf below do not overwrite them.
    plot_epistemic_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=None
    )
    _save_current_figure(dir_name + "epistemic_var_init_traj.pdf")
    plot_aleatoric_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=None
    )
    _save_current_figure(dir_name + "aleatoric_var_init_traj.pdf")

    fig, ax = plot_svgp_metric_trace_and_start_end(metric, solver)
    plot_traj(fig, ax, solver.state_guesses)
    _save_current_figure(dir_name + "metric_trace_2d_init_traj.pdf")

    # These figures are only displayed by plt.show() below; they are not saved.
    plot_svgp_jacobian_mean(gp, solver, traj_opt=solver.state_guesses)
    plot_svgp_jacobian_var(gp, solver, traj_opt=solver.state_guesses)
    plot_svgp_metric_and_start_end(
        metric, solver, traj_opt=solver.state_guesses
    )

    plt.show()

    # ---- Solve for the geodesic trajectory ----
    geodesic_traj = test_collocation_svgp_quad()

    # ---- Figures overlaying the initial and optimised trajectories ----
    plot_svgp_and_start_end(gp, solver, traj_opt=geodesic_traj)
    _save_current_figure(dir_name + "svgp_2d_traj.pdf")

    plot_mixing_prob_and_start_end(gp, solver, traj_opt=geodesic_traj)
    _save_current_figure(dir_name + "mixing_prob_2d_traj.pdf")

    plot_svgp_metric_trace_and_start_end(metric, solver, traj_opt=geodesic_traj)
    _save_current_figure(dir_name + "metric_trace_2d_traj.pdf")

    # Quantities along each trajectory, plotted against time.
    plot_mixing_prob_over_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    _save_current_figure(dir_name + "mixing_prob_vs_time.pdf")
    plot_metirc_trace_over_time(
        metric, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    _save_current_figure(dir_name + "metric_trace_vs_time.pdf")

    plot_epistemic_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    _save_current_figure(dir_name + "epistemic_var_traj.pdf")
    plot_aleatoric_var_vs_time(
        gp, solver, traj_init=solver.state_guesses, traj_opt=geodesic_traj
    )
    _save_current_figure(dir_name + "aleatoric_var_traj.pdf")

    # Passed positionally, so the array is stored under numpy's default key
    # "arr_0" inside the archive.
    np.savez(dir_name + "geodesic_traj.npz", geodesic_traj)

    plt.show()
コード例 #21
0
ファイル: run_sims.py プロジェクト: dimarkov/pybefit
            params.append(seq_sim[1].copy())

        c0 = np.sum(
            nn.one_hot(outcomes_sim[cutoff:], 4).copy().astype(float), 0)

        # fit simulated data
        seqs = device_put((np.stack(priors, 0), np.stack(
            params, 0), outcomes_sim[cutoff:].copy(), c0), device)

        y = device_put(responses_sim[cutoff:].copy(), device)
        mask = jnp.ones_like(y).astype(bool)

        rng_keys = random.split(rng_key, N)

        samples = jit(vmap(inference,
                           in_axes=((2, 2, 1, 0), 1, 1, 0)))(seqs, y, mask,
                                                             rng_keys)

        post_smpl[m_true] = samples

        jnp.savez('fit_sims/tmp_sims_m{}.npz'.format(m_true), samples=samples)

        del seq_sim, samples, agent

    # save posterior estimates
    jnp.savez('fit_sims/sims_mcomp_numin_P-{}-{}-{}-{}.npz'.format(*P_o),
              samples=post_smpl)

    # delete tmp files
    for m_true in m_rng:
        os.remove('fit_sims/tmp_sims_m{}.npz'.format(m_true))