コード例 #1
0
def test_diverging(kernel_cls, adapt_step_size):
    """Check divergence counts for a kernel started with ``step_size=10.``.

    The kernel is built without mass-matrix adaptation. Divergences are
    summed over the collected warmup phase and the sampling phase:

    * with ``adapt_step_size`` truthy, the total must not exceed the number
      of warmup iterations;
    * otherwise the total must equal the number of warmup plus sampling
      iterations, i.e. every transition is flagged as diverging.
    """
    num_warmup, num_samples = 1000, 1000
    observations = random.normal(random.PRNGKey(0), (1000, ))

    def model(values):
        loc = numpyro.sample('loc', dist.Normal(0., 1.))
        numpyro.sample('obs', dist.Normal(loc, 1), obs=values)

    kernel_options = dict(step_size=10.,
                          adapt_step_size=adapt_step_size,
                          adapt_mass_matrix=False)
    mcmc = MCMC(kernel_cls(model, **kernel_options), num_warmup, num_samples)

    def count_divergences():
        # Sum of the 'diverging' flags collected by the most recent phase.
        return mcmc.get_extra_fields()['diverging'].sum()

    mcmc.warmup(random.PRNGKey(1),
                observations,
                extra_fields=['diverging'],
                collect_warmup=True)
    num_divergences = count_divergences()

    mcmc.run(random.PRNGKey(2), observations, extra_fields=['diverging'])
    num_divergences = num_divergences + count_divergences()

    if not adapt_step_size:
        assert_allclose(num_divergences, num_warmup + num_samples)
    else:
        assert num_divergences <= num_warmup
コード例 #2
0
    def _sample(current_state, seed):
        """Draw posterior samples with NumPyro's NUTS starting at ``current_state``.

        The potential function is the negated ``logp_fn_jax`` applied to the
        unpacked state. Tuning length, number of draws, number of chains,
        chain method, target acceptance and progress-bar settings are read
        from the enclosing scope.

        :param current_state: initial parameters passed as ``init_params``
        :param seed: random key handed to ``MCMC.run``
        :return: ``(samples, leapfrogs_taken)``, both grouped by chain;
            ``leapfrogs_taken`` is the collected ``"num_steps"`` extra field
        """
        nuts_kernel = NUTS(
            potential_fn=lambda x: -logp_fn_jax(*x),
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        pmap_numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method=chain_method,
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed,
                         init_params=current_state,
                         extra_fields=("num_steps", ))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(
            group_by_chain=True)["num_steps"]
        return samples, leapfrogs_taken
コード例 #3
0
def main() -> None:
    """Fit ``model_noncentered`` with NUTS and print summaries and predictions.

    Steps performed:

    1. Run 500 warmup and 1000 sampling iterations on the eight hard-coded
       observations ``y`` with standard deviations ``sigma``, collecting the
       ``"potential_energy"`` extra field.
    2. Print the MCMC summary and the mean of the negated potential energy.
    3. Print the mean of ``"obs"`` drawn from ``model_pred`` twice: once with
       100 samples and no posterior ("prior"), once conditioned on the MCMC
       samples ("posterior").
    """
    # Data
    num = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

    # Random key
    rng_key = random.PRNGKey(0)

    # Inference
    nuts_kernel = NUTS(model_noncentered)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
    mcmc.run(rng_key, num, sigma, y=y, extra_fields=("potential_energy", ))
    # print_summary() writes the table itself and returns None, so wrapping
    # it in print() would emit a stray "None" line.
    mcmc.print_summary()

    # Extra
    pe = mcmc.get_extra_fields()["potential_energy"]
    print(f"Expected log joint density: {np.mean(-pe):.2f}")

    # Prediction
    predictive = Predictive(model_pred, num_samples=100)
    samples = predictive(random.PRNGKey(1))
    print("prior", np.mean(samples["obs"]))

    predictive = Predictive(model_pred, mcmc.get_samples())
    samples = predictive(random.PRNGKey(1))
    print("posterior", np.mean(samples["obs"]))
コード例 #4
0
def test_estimate_likelihood(kernel_cls):
    """Compare the HMCECS potential energy with the full-data log density.

    An HMCECS kernel (Taylor proxy centred on ``ref_params``, 20 blocks)
    wraps ``kernel_cls(model)`` and is run on 10,000 three-dimensional
    observations with a subsample size of 100. For every posterior sample
    the full-data log density is evaluated by substituting all indices for
    the plate ``"N"``; the variance of ``exp(-pes - pes_full)`` must stay
    below 1.
    """
    num_warmup, num_samples, num_blocks = 200, 200, 20
    sigma = 0.1
    ref_params = jnp.array([0.1, 0.5, -0.2])

    # Only the first of the four split keys is consumed.
    data_key = random.split(random.PRNGKey(0), 4)[0]
    noise = dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(data_key, (10_000,))
    data = ref_params + noise
    n = data.shape[0]

    def model(data):
        prior = dist.Normal(ref_params, jnp.ones_like(ref_params))
        mean = numpyro.sample("mean", prior)
        with numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2) as idx:
            numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])

    kernel = HMCECS(
        kernel_cls(model),
        proxy=HMCECS.taylor_proxy({"mean": ref_params}),
        num_blocks=num_blocks,
    )
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), data, extra_fields=["hmc_state.potential_energy"])

    pes = mcmc.get_extra_fields()["hmc_state.potential_energy"]

    def full_log_density(sample):
        # Substitute every data index for the subsampling plate "N".
        params = dict(sample)
        params["N"] = jnp.arange(n)
        return log_density(model, (data,), {}, params)[0]

    pes_full = vmap(full_log_density)(mcmc.get_samples())

    assert jnp.var(jnp.exp(-pes - pes_full)) < 1.0
コード例 #5
0
def test_extra_fields():
    """Verify that requested extra fields are collected with per-chain shapes.

    A NUTS run with 1000 warmup and 1000 sampling iterations on a
    five-element standard-normal site must return samples of shape
    ``(1, 1000, 5)`` and each requested extra field with shape ``(1, 1000)``
    when grouped by chain.
    """
    requested_fields = ('num_steps', 'adapt_state.step_size')

    def model():
        numpyro.sample('x', dist.Normal(0, 1), sample_shape=(5,))

    mcmc = MCMC(NUTS(model), 1000, 1000)
    mcmc.run(random.PRNGKey(0), extra_fields=requested_fields)

    assert mcmc.get_samples(group_by_chain=True)['x'].shape == (1, 1000, 5)

    stats = mcmc.get_extra_fields(group_by_chain=True)
    for field in requested_fields:
        assert field in stats
        assert stats[field].shape == (1, 1000)
コード例 #6
0
def test_extra_fields():
    """Check shapes of samples and extra fields returned grouped by chain.

    Runs NUTS (1000 warmup, 1000 samples) on a model with a single
    five-element normal site while collecting ``"num_steps"`` and
    ``"adapt_state.step_size"``.
    """

    def model():
        numpyro.sample("x", dist.Normal(0, 1), sample_shape=(5,))

    sampler = MCMC(NUTS(model), 1000, 1000)
    sampler.run(
        random.PRNGKey(0),
        extra_fields=("num_steps", "adapt_state.step_size"),
    )

    draws = sampler.get_samples(group_by_chain=True)
    assert draws["x"].shape == (1, 1000, 5)

    # One chain, one scalar per draw for each collected field.
    per_draw_shape = (1, 1000)
    extras = sampler.get_extra_fields(group_by_chain=True)
    assert "num_steps" in extras
    assert extras["num_steps"].shape == per_draw_shape
    assert "adapt_state.step_size" in extras
    assert extras["adapt_state.step_size"].shape == per_draw_shape
コード例 #7
0
ファイル: sampling_jax.py プロジェクト: AlexAndorra/pymc3
    def _sample(*inputs):
        """Sample with NumPyro's NUTS, treating trailing inputs as shared variables.

        When ``op.nshared`` is positive, the last ``op.nshared`` entries of
        ``inputs`` are excluded from the sampler state and the corresponding
        graph inputs are read as shared variables instead; otherwise all of
        ``inputs`` form the initial state.

        :return: the per-chain samples as a tuple, followed by the per-chain
            ``"num_steps"`` extra field as the final element
        """
        n_shared = op.nshared
        if n_shared > 0:
            current_state = inputs[:-n_shared]
            shared_inputs = tuple(op.fgraph.inputs[-n_shared:])
        else:
            current_state, shared_inputs = inputs, ()

        def log_fn_wrap(x):
            # The shared values are fetched manually and appended as extra
            # arguments to the compiled "inner" function.
            shared_values = tuple(
                v.get_value(borrow=True, return_internal_type=True)
                for v in shared_inputs)
            res = logp_fn(*(x + shared_values))

            # The new JAX backend always returns a tuple; keep its first item.
            if isinstance(res, (list, tuple)):
                res = res[0]

            return -res

        kernel_settings = dict(
            potential_fn=log_fn_wrap,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )
        pmap_numpyro = MCMC(
            NUTS(**kernel_settings),
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            chain_method="parallel",
            progress_bar=progress_bar,
        )

        pmap_numpyro.run(seed,
                         init_params=current_state,
                         extra_fields=("num_steps", ))

        samples = pmap_numpyro.get_samples(group_by_chain=True)
        extra = pmap_numpyro.get_extra_fields(group_by_chain=True)
        leapfrogs_taken = extra["num_steps"]
        return (*samples, leapfrogs_taken)
コード例 #8
0
ファイル: run_fits_single.py プロジェクト: dimarkov/pybefit
    def inference(belief_sequences, obs, mask, rng_key):
        """Run dense-mass NUTS on ``model`` with vectorized chains.

        ``num_warmup``, ``num_samples`` and ``num_chains`` come from the
        enclosing scope; the progress bar is disabled.

        :param belief_sequences: first positional model argument
        :param obs: second positional model argument
        :param mask: third positional model argument
        :param rng_key: random key passed to ``MCMC.run``
        :return: ``(samples, potential_energy)`` where ``potential_energy``
            is the mean of the collected ``'potential_energy'`` field
        """
        mcmc = MCMC(
            NUTS(model, dense_mass=True),
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            chain_method="vectorized",
            progress_bar=False,
        )
        mcmc.run(rng_key, belief_sequences, obs, mask,
                 extra_fields=('potential_energy',))

        samples = mcmc.get_samples()
        energies = mcmc.get_extra_fields()['potential_energy']
        return samples, energies.mean()
コード例 #9
0
ファイル: run_fits.py プロジェクト: dimarkov/pybefit
def main(m, seed):
    """Fit ``generative_model`` with NUTS, save samples and WAIC, print mean energy.

    Beliefs are estimated with ``nu_max=m`` when ``m <= 10`` and with
    ``nu_max=10, nu_min=m - 10`` otherwise. Beliefs, responses and mask are
    all restricted to indices 100 up to (but excluding) 800 before fitting.

    :param m: selects the belief-estimation bounds and tags the output file
    :param seed: integer seed for the JAX random key
    """
    window = slice(100, 800)
    rng_key = random.PRNGKey(seed)

    if m > 10:
        bounds = {"nu_max": 10, "nu_min": m - 10}
    else:
        bounds = {"nu_max": m}
    seq, _ = estimate_beliefs(outcomes_data, responses_data, mask=mask_data, **bounds)

    model = generative_model
    mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)

    beliefs = seq['beliefs']
    seq = (beliefs[0][window], beliefs[1][window])
    responses = responses_data[window]
    mask = mask_data[window].astype(bool)

    _, run_key = random.split(rng_key)
    mcmc.run(run_key, seq, y=responses, mask=mask,
             extra_fields=('potential_energy',))

    samples = mcmc.get_samples()
    density = log_pred_density(model, samples, seq, y=responses, mask=mask)
    waic = density['waic']

    jnp.savez(f'fit_waic_sample/dyn_fit_waic_sample_minf{m}.npz',
              samples=samples, waic=waic)

    print(mcmc.get_extra_fields()['potential_energy'].mean())
コード例 #10
0
def benchmark_hmc(args, features, labels):
    """Run one MCMC algorithm on ``model`` and print timing and acceptance stats.

    The kernel is chosen by ``args.algo``; accepted values are ``"HMC"``,
    ``"NUTS"``, ``"HMCECS"``, ``"SA"`` and ``"FlowHMCECS"``. The two
    ``*HMCECS`` variants pass ``subsample_size=1000`` to the model, the
    others pass ``None``.

    :param args: object with attributes ``algo``, ``num_warmup`` and
        ``num_samples``; ``num_steps`` is read for ``"HMC"`` and
        ``dense_mass`` for ``"HMC"``, ``"NUTS"`` and ``"HMCECS"``
    :param features: first positional model argument
    :param labels: second positional model argument
    :raises ValueError: if ``args.algo`` is not one of the accepted values
    """
    rng_key = random.PRNGKey(1)
    start = time.time()
    # a MAP estimate at the following source
    # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
    ref_params = {
        "coefs":
        jnp.array([
            +2.03420663e00,
            -3.53567265e-02,
            -1.49223924e-01,
            -3.07049364e-01,
            -1.00028366e-01,
            -1.46827862e-01,
            -1.64167881e-01,
            -4.20344204e-01,
            +9.47479829e-02,
            -1.12681836e-02,
            +2.64442056e-01,
            -1.22087866e-01,
            -6.00568838e-02,
            -3.79419506e-01,
            -1.06668741e-01,
            -2.97053963e-01,
            -2.05253899e-01,
            -4.69537191e-02,
            -2.78072730e-02,
            -1.43250525e-01,
            -6.77954629e-02,
            -4.34899796e-03,
            +5.90927452e-02,
            +7.23133609e-02,
            +1.38526391e-02,
            -1.24497898e-01,
            -1.50733739e-02,
            -2.68872194e-02,
            -1.80925727e-02,
            +3.47936489e-02,
            +4.03552800e-02,
            -9.98773426e-03,
            +6.20188080e-02,
            +1.15002751e-01,
            +1.32145107e-01,
            +2.69109547e-01,
            +2.45785132e-01,
            +1.19035013e-01,
            -2.59744357e-02,
            +9.94279515e-04,
            +3.39266285e-02,
            -1.44057125e-02,
            -6.95222765e-02,
            -7.52013028e-02,
            +1.21171586e-01,
            +2.29205526e-02,
            +1.47308692e-01,
            -8.34354162e-02,
            -9.34122875e-02,
            -2.97472421e-02,
            -3.03937674e-01,
            -1.70958012e-01,
            -1.59496680e-01,
            -1.88516974e-01,
            -1.20889175e00,
        ])
    }
    if args.algo == "HMC":
        step_size = jnp.sqrt(0.5 / features.shape[0])
        trajectory_length = step_size * args.num_steps
        kernel = HMC(
            model,
            step_size=step_size,
            trajectory_length=trajectory_length,
            adapt_step_size=False,
            dense_mass=args.dense_mass,
        )
        subsample_size = None
    elif args.algo == "NUTS":
        kernel = NUTS(model, dense_mass=args.dense_mass)
        subsample_size = None
    elif args.algo == "HMCECS":
        subsample_size = 1000
        inner_kernel = NUTS(
            model,
            init_strategy=init_to_value(values=ref_params),
            dense_mass=args.dense_mass,
        )
        # note: with num_blocks=100, 10 indices are updated at each MCMC
        # step, so (per the original note) it takes 50000 MCMC steps to
        # iterate over the whole dataset
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(ref_params))
    elif args.algo == "SA":
        # NB: this kernel requires large num_warmup and num_samples
        # and running on GPU is much faster than on CPU
        kernel = SA(model,
                    adapt_state_size=1000,
                    init_strategy=init_to_value(values=ref_params))
        subsample_size = None
    elif args.algo == "FlowHMCECS":
        subsample_size = 1000
        guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
        svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
        svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
        params, losses = svi_result.params, svi_result.losses
        plt.plot(losses)
        plt.show()

        neutra = NeuTraReparam(guide, params)
        neutra_model = neutra.reparam(model)
        neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
        # no need to adapt mass matrix if the flow does a good job
        inner_kernel = NUTS(
            neutra_model,
            init_strategy=init_to_value(values=neutra_ref_params),
            adapt_mass_matrix=False,
        )
        kernel = HMCECS(inner_kernel,
                        num_blocks=100,
                        proxy=HMCECS.taylor_proxy(neutra_ref_params))
    else:
        # List every branch handled above so the message matches the code.
        raise ValueError(
            "Invalid algorithm, either 'HMC', 'NUTS', 'HMCECS', 'SA', "
            "or 'FlowHMCECS'.")
    mcmc = MCMC(kernel,
                num_warmup=args.num_warmup,
                num_samples=args.num_samples)
    mcmc.run(rng_key,
             features,
             labels,
             subsample_size,
             extra_fields=("accept_prob", ))
    print("Mean accept prob:",
          jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
    mcmc.print_summary(exclude_deterministic=False)
    # `start` was taken before kernel construction, so for "FlowHMCECS" the
    # elapsed time also includes the SVI fit and the blocking plot.
    print("\nMCMC elapsed time:", time.time() - start)
コード例 #11
0
def run_model(
    model_func,
    data,
    ep,
    num_samples=500,
    num_warmup=500,
    num_chains=4,
    target_accept=0.75,
    max_tree_depth=15,
    save_results=True,
    output_fname=None,
    model_kwargs=None,
    save_json=False,
    chain_method="parallel",
    heuristic_step_size=True,
):
    """
    Model run utility

    Runs a collected warmup phase followed by a sampling phase with NUTS,
    gathers diagnostics (step counts, step sizes, acceptance probabilities,
    final warmup inverse mass matrices, divergences, ESS and, for more than
    one chain, Rhat) into ``info_dict`` and optionally saves the results.

    :param model_func: numpyro model
    :param data: PreprocessedData object
    :param ep: EpidemiologicalParameters object
    :param num_samples: number of samples
    :param num_warmup: number of warmup samples
    :param num_chains: number of chains
    :param target_accept: target accept
    :param max_tree_depth: maximum treedepth
    :param save_results: whether to save full results
    :param output_fname: output filename; when None a name is built from the
        model function name and the current time
    :param model_kwargs: model kwargs -- extra arguments for the model function
    :param save_json: whether to save json (only honoured when
        ``save_results`` is true)
    :param chain_method: Numpyro chain method to use
    :param heuristic_step_size: whether to find a heuristic step size
    :return: posterior_samples, warmup_samples, info_dict (dict with assorted diagnostics), Numpyro mcmc object
    """
    print(
        f"Running {num_chains} chains, {num_samples} per chain with {num_warmup} warmup steps"
    )
    nuts_kernel = NUTS(
        model_func,
        init_strategy=init_to_median,
        target_accept_prob=target_accept,
        max_tree_depth=max_tree_depth,
        find_heuristic_step_size=heuristic_step_size,
    )
    mcmc = MCMC(
        nuts_kernel,
        num_samples=num_samples,
        num_warmup=num_warmup,
        num_chains=num_chains,
        chain_method=chain_method,
    )
    rng_key = random.PRNGKey(0)

    # hmcstate = nuts_kernel.init(rng_key, 1, model_args=(data, ep))
    # nRVs = hmcstate.adapt_state.inverse_mass_matrix.size
    # inverse_mass_matrix = init_diag_inv_mass_mat * jnp.ones(nRVs)
    # mass_matrix_sqrt_inv = np.sqrt(inverse_mass_matrix)
    # mass_matrix_sqrt = 1./mass_matrix_sqrt_inv
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(inverse_mass_matrix=inverse_mass_matrix))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt_inv=mass_matrix_sqrt_inv))
    # hmcstate = hmcstate._replace(adapt_state=hmcstate.adapt_state._replace(mass_matrix_sqrt=mass_matrix_sqrt))
    # mcmc.post_warmup_state = hmcstate

    info_dict = {
        "model_name": model_func.__name__,
    }

    # The timer starts before warmup, so the runtime figures below cover
    # warmup as well as sampling.
    start = time.time()
    if model_kwargs is None:
        model_kwargs = {}

    info_dict["model_kwargs"] = model_kwargs

    # also collect some extra information for better diagnostics!
    print(f"Warmup Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.warmup(
        rng_key,
        data,
        ep,
        **model_kwargs,
        collect_warmup=True,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )
    mcmc.get_extra_fields()["num_steps"].block_until_ready()

    info_dict["warmup"] = {}
    info_dict["warmup"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["warmup"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()
    info_dict["warmup"]["inverse_mass_matrix"] = {}

    # Split the chain-concatenated warmup mass matrices into one slab per
    # chain along the leading axis.
    all_mass_mats = jnp.array(
        jnp.array_split(
            mcmc.get_extra_fields()["adapt_state"].inverse_mass_matrix,
            num_chains,
            axis=0,
        ))

    print(all_mass_mats.shape)

    # Keep only the last warmup entry of each chain.
    for i in range(num_chains):
        info_dict["warmup"]["inverse_mass_matrix"][
            f"chain_{i}"] = all_mass_mats[i, -1, :].tolist()

    info_dict["warmup"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()

    warmup_samples = mcmc.get_samples()

    print(f"Sample Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    mcmc.run(
        rng_key,
        data,
        ep,
        **model_kwargs,
        extra_fields=["num_steps", "mean_accept_prob", "adapt_state"],
    )

    posterior_samples = mcmc.get_samples()
    # if you don't block this, the timer won't quite work properly.
    posterior_samples[list(posterior_samples.keys())[0]].block_until_ready()
    print(f"Sample Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    end = time.time()
    time_per_sample = float(end - start) / num_samples
    # "diverging" is not in extra_fields above; this relies on the kernel
    # collecting it by default.
    divergences = int(mcmc.get_extra_fields()["diverging"].sum())

    info_dict["time_per_sample"] = time_per_sample
    info_dict["total_runtime"] = float(end - start)
    info_dict["divergences"] = divergences

    info_dict["sample"] = {}
    info_dict["sample"]["num_steps"] = np.array(
        mcmc.get_extra_fields()["num_steps"]).tolist()
    info_dict["sample"]["mean_accept_prob"] = np.array(
        mcmc.get_extra_fields()["mean_accept_prob"]).tolist()
    info_dict["sample"]["step_size"] = np.array(
        mcmc.get_extra_fields()["adapt_state"].step_size).tolist()

    print(f"Sampling {num_samples} samples per chain took {end - start:.2f}s")
    print(f"There were {divergences} divergences.")

    grouped_posterior_samples = mcmc.get_samples(True)

    # Flatten the ESS of every site into one 1-D array.
    all_ess = np.array([])
    for k in grouped_posterior_samples.keys():
        ess = numpyro.diagnostics.effective_sample_size(
            np.asarray(grouped_posterior_samples[k]))
        all_ess = np.append(all_ess, ess)

    print(f"{np.sum(np.isnan(all_ess))}  ESS were nan")
    all_ess = all_ess[np.logical_not(np.isnan(all_ess))]

    info_dict["ess"] = {
        "med": float(np.percentile(all_ess, 50)),
        "lower": float(np.percentile(all_ess, 2.5)),
        "upper": float(np.percentile(all_ess, 97.5)),
        "min": float(np.min(all_ess)),
        "max": float(np.max(all_ess)),
    }
    # NOTE(review): the label says "Mean" but the value printed is the
    # median; the message text is left unchanged here.
    print(
        f"Mean ESS: {info_dict['ess']['med']:.2f} [{info_dict['ess']['lower']:.2f} ... {info_dict['ess']['upper']:.2f}]"
    )

    # Rhat is only computed when there is more than one chain.
    if num_chains > 1:
        all_rhat = np.array([])
        for k in grouped_posterior_samples.keys():
            rhat = numpyro.diagnostics.gelman_rubin(
                np.asarray(grouped_posterior_samples[k]))
            all_rhat = np.append(all_rhat, rhat)

        print(f"{np.sum(np.isnan(all_rhat))} Rhat were nan")
        all_rhat = all_rhat[np.logical_not(np.isnan(all_rhat))]

        info_dict["rhat"] = {
            "med": float(np.percentile(all_rhat, 50)),
            "upper": float(np.percentile(all_rhat, 97.5)),
            "lower": float(np.percentile(all_rhat, 2.5)),
            # Previously "min" stored the maximum and "max" the minimum.
            "min": float(np.min(all_rhat)),
            "max": float(np.max(all_rhat)),
        }

        print(
            f"Rhat: {info_dict['rhat']['med']:.2f} [{info_dict['rhat']['lower']:.2f} ... {info_dict['rhat']['upper']:.2f}]"
        )

    if save_results:
        print("Saving .netcdf")
        # Saving is best-effort: a failure is printed and the in-memory
        # results are still returned.
        try:
            inf_data = az.from_numpyro(mcmc)

            if output_fname is None:
                output_fname = f'{model_func.__name__}-{datetime.now(tz=None).strftime("%d-%m;%H-%M-%S")}.netcdf'

            az.to_netcdf(inf_data, output_fname)

            json_fname = output_fname.replace(".netcdf", ".json")
            if save_json:
                print("Saving Json")
                with open(json_fname, "w") as f:
                    json.dump(info_dict, f, ensure_ascii=False, indent=4)

        except Exception as e:
            print(e)

    return posterior_samples, warmup_samples, info_dict, mcmc