Exemplo n.º 1
0
def test_nuts_sampling_runs():
    """Smoke-test NUTS sampling and check the shape and dtype of its outputs.

    Samples two chains of three draws (one tuning step, ``cores=1``) on a
    one-dimensional model, then requires that the trace and every sampler
    statistic are shaped ``(chains, draws, model_ndim)`` and that each
    statistic has the dtype declared in ``step.stats_dtypes[0]``.
    """
    model_ndim = 1
    draws, tune = 3, 1
    chains, cores = 2, 1
    expected_shape = (chains, draws, model_ndim)
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=model_ndim)

    trace, stats = lmc.sample(
        logp_dlogp_func,
        model_ndim,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )

    assert trace.shape == expected_shape

    # The step method declares its statistics as a name -> dtype mapping;
    # each declared statistic must be present in ``stats``.
    declared = step.stats_dtypes[0]

    # Lists (not generators) so every statistic is looked up before ``all``.
    shape_matches = [stats[name].shape == expected_shape for name in declared]
    assert all(shape_matches)

    dtype_matches = [
        stats[name].dtype == dtype for name, dtype in declared.items()
    ]
    assert all(dtype_matches)
def test_multiprocessing_with_various_frameworks(framework):
    """Sample four chains on four cores with the chosen framework's log-density.

    Args:
        framework: One of ``"pytorch"``, ``"jax"`` or ``"pymc3"``; any other
            value raises ``KeyError`` from the lookup below.
            NOTE(review): presumably supplied by a parametrize decorator
            that is outside this view -- confirm.

    The only check on the result is that the trace is shaped
    ``(chains, draws, model_ndim)``.
    """
    backends = {
        "pytorch": torch_logp_dlogp_func,
        "jax": jax_logp_dlogp_func,
        "pymc3": pm_logp_dlogp_func,
    }

    n_dim = 3
    n_draws = 10
    n_chains = 4

    trace, stats = lmc.sample(
        logp_dlogp_func=backends[framework],
        model_ndim=n_dim,
        tune=10,
        draws=n_draws,
        chains=n_chains,
        cores=4,
        progressbar=False,
    )

    expected_shape = (n_chains, n_draws, n_dim)
    assert trace.shape == expected_shape
Exemplo n.º 3
0
def test_reset_tuning():
    """Check the step method's tuning counters after sampling two chains.

    The ``start``/``step`` pair comes from ``lmc.init_nuts`` and is passed to
    ``lmc.sample`` with ``chains=2`` and ``cores=1``. Afterwards the step's
    private counters must equal ``tune`` and ``tune + 1`` respectively.

    NOTE(review): presumably this guards against tuning state accumulating
    across chains (the counters match a single chain's ``tune`` although two
    chains ran) -- confirm against the sampler implementation.
    """
    model_ndim = 1
    draws, tune = 2, 50
    chains, cores = 2, 1

    start, step = lmc.init_nuts(
        logp_dlogp_func=logp_dlogp_func,
        model_ndim=model_ndim,
    )

    lmc.sample(
        logp_dlogp_func,
        model_ndim,
        draws=draws,
        tune=tune,
        chains=chains,
        step=step,
        start=start,
        cores=cores,
    )

    # These are private attributes of the step method's internals.
    n_potential_samples = step.potential._n_samples
    assert n_potential_samples == tune
    n_step_adaptations = step.step_adapt._count
    assert n_step_adaptations == tune + 1
Exemplo n.º 4
0
def test_multiprocess_sampling_runs():
    """Smoke-test sampling with ``chains`` and ``cores`` both left as ``None``.

    There are no assertions on the values returned; the test passes when
    ``lmc.sample`` returns a two-item result without raising.

    NOTE(review): how ``None`` is resolved (and whether it leads to several
    processes, as the test name suggests) is decided inside ``lmc.sample`` --
    confirm there.
    """
    size = 1
    draws = tune = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)

    # The unpacking is kept on purpose: it fails if the return value is not
    # a pair, even though neither name is used afterwards.
    trace, stats = lmc.sample(
        logp_dlogp_func,
        size,
        draws,
        tune,
        step=step,
        chains=None,
        cores=None,
    )
Exemplo n.º 5
0
def test_nuts_sampling_runs():
    """Smoke-test single-chain NUTS sampling and check the trace shape.

    Samples two draws after one tuning step on a size-1 model with one chain
    on one core, and requires the trace to be shaped ``(1, 2)``.

    NOTE(review): the expected shape is a literal; which of ``size``,
    ``chains`` and ``draws`` each axis corresponds to cannot be told from
    this test alone -- confirm against ``lmc.sample``.
    """
    size = 1
    draws, tune = 2, 1
    chains = cores = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)

    trace, stats = lmc.sample(
        logp_dlogp_func,
        size,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )

    expected_shape = (1, 2)
    assert trace.shape == expected_shape
Exemplo n.º 6
0
def test_samples_not_all_same():
    """Check that consecutive draws are not all identical.

    Draws 50 samples after 10 tuning steps from a single chain on one core,
    without passing an explicit step method, and requires the trace to have
    strictly positive variance.
    """
    model_ndim = 1
    draws, tune = 50, 10
    chains = cores = 1

    trace, stats = lmc.sample(
        logp_dlogp_func,
        model_ndim,
        draws,
        tune,
        chains=chains,
        cores=cores,
    )

    # Zero variance would mean every sampled value is the same.
    trace_variance = np.var(trace)
    assert trace_variance > 0
Exemplo n.º 7
0
def test_nuts_recovers_1d_normal():
    """Check that NUTS draws have mean near 0 and standard deviation near 1.

    Runs one chain of 1000 draws after 1000 tuning steps on a size-1 model.
    Both checks use an absolute tolerance of 1, so they only catch gross
    failures.

    NOTE(review): the test name implies ``logp_dlogp_func`` is a standard
    normal log-density; it is defined outside this block -- confirm.
    """
    size = 1
    draws = tune = 1000
    chains = cores = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)

    trace, stats = lmc.sample(
        logp_dlogp_func,
        size,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )

    sample_mean = np.mean(trace)
    assert np.allclose(sample_mean, 0, atol=1)
    sample_std = np.std(trace)
    assert np.allclose(sample_std, 1, atol=1)
Exemplo n.º 8
0
def test_multiprocess_sampling_runs():
    """Smoke-test NUTS sampling with four chains on four cores.

    One draw and one tuning step per chain keep the run short. There are no
    assertions on the values returned; the test passes when ``lmc.sample``
    returns a two-item result without raising.
    """
    model_ndim = 1
    draws = tune = 1
    chains = cores = 4
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=model_ndim)

    # The unpacking is kept on purpose: it fails if the return value is not
    # a pair, even though neither name is used afterwards.
    trace, stats = lmc.sample(
        logp_dlogp_func,
        model_ndim,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )
Exemplo n.º 9
0
def test_nuts_tuning():
    """Check that the NUTS step's ``tune`` flag is falsy after sampling.

    Runs one chain of five draws after five tuning steps on a
    one-dimensional model, then inspects the step object that was passed in.

    NOTE(review): presumably the flag is switched off once the tuning phase
    ends -- confirm against the step implementation.
    """
    model_ndim = 1
    draws = tune = 5
    chains = cores = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=model_ndim)

    trace, stats = lmc.sample(
        logp_dlogp_func,
        model_ndim,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )

    assert not step.tune
Exemplo n.º 10
0
def test_nuts_tuning():
    """Check that the NUTS step's ``tune`` flag is falsy after sampling.

    Runs one chain of five draws after five tuning steps on a size-1 model,
    then inspects the step object that was passed in.

    NOTE(review): presumably the flag is switched off once the tuning phase
    ends -- confirm against the step implementation.
    """
    size = 1
    draws = tune = 5
    chains = cores = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)

    trace, stats = lmc.sample(
        logp_dlogp_func,
        size,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )

    assert not step.tune
Exemplo n.º 11
0
def test_hmc_recovers_1d_normal():
    """Check that HamiltonianMC draws have mean near 0 and std near 1.

    Runs one chain of 1000 draws after 1000 tuning steps on a
    one-dimensional model. Both checks use an absolute tolerance of 1, so
    they only catch gross failures.

    NOTE(review): the test name implies ``logp_dlogp_func`` is a standard
    normal log-density; it is defined outside this block -- confirm.
    """
    model_ndim = 1
    draws = tune = 1000
    chains = cores = 1
    step = lmc.HamiltonianMC(
        logp_dlogp_func=logp_dlogp_func,
        model_ndim=model_ndim,
    )

    trace, stats = lmc.sample(
        logp_dlogp_func,
        model_ndim,
        draws,
        tune,
        step=step,
        chains=chains,
        cores=cores,
    )

    sample_mean = np.mean(trace)
    assert np.allclose(sample_mean, 0, atol=1)
    sample_std = np.std(trace)
    assert np.allclose(sample_std, 1, atol=1)
# Script the model once at module level; the sampler callback below reuses it.
torch_model = torch.jit.script(LinearModel())

# The three model parameters, in the order the callback indexes its input.
torch_params = [torch_model.m, torch_model.b, torch_model.logs]

# Model inputs converted to double-precision tensors.
args = [torch.tensor(data, dtype=torch.double) for data in (x, y_obs)]


def torch_logp_dlogp_func(x):
    """Evaluate the scripted model and its gradient at parameter values ``x``.

    The module-level ``torch_params`` are overwritten in place with the
    entries of ``x`` (``x[i]`` goes to the i-th parameter), the model is run
    on the module-level ``args``, and ``backward()`` is called on its output.

    Args:
        x: Indexable of parameter values with at least
            ``len(torch_params)`` entries; extra entries are ignored.

    Returns:
        A pair ``(value, grads)``: the detached model output as a NumPy
        array, and a NumPy array stacking each parameter's gradient in
        ``torch_params`` order.
        NOTE(review): the function name implies the model output is a
        log-probability; the model is defined outside this block -- confirm.
    """
    for idx, param in enumerate(torch_params):
        param.data = torch.tensor(x[idx])
        if param.grad is None:
            # Nothing to clear before the first backward pass.
            continue
        # backward() adds to existing gradients, so clear the previous call's.
        param.grad.detach_()
        param.grad.zero_()

    output = torch_model(*args)
    output.backward()

    value = output.detach().numpy()
    grads = np.array([param.grad.numpy() for param in torch_params])
    return value, grads


# Sample the three-parameter PyTorch model with littlemcmc: 4 chains, 500
# tuning steps and 1000 draws, using `torch_logp_dlogp_func` as the callback.
# No `step` or `cores` is passed, so both are left to `lmc.sample`'s defaults.
# NOTE(review): this call runs at module level, i.e. whenever the file is
# executed or imported -- confirm that is intended (script/notebook usage).
trace, stats = lmc.sample(
    logp_dlogp_func=torch_logp_dlogp_func,
    model_ndim=3,
    tune=500,
    draws=1000,
    chains=4,
)