def test_nuts_sampling_runs():
    """Run a short two-chain NUTS sample and check trace and stats layout.

    Verifies the trace shape and, for every sampler statistic advertised in
    ``step.stats_dtypes[0]``, both its array shape and its dtype.
    """
    # NOTE(review): another ``test_nuts_sampling_runs`` is defined later in
    # this SOURCE; if both share a module, the later one replaces this one
    # and pytest never collects this version.
    ndim = 1
    nuts = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=ndim)
    n_draws, n_tune = 3, 1
    n_chains, n_cores = 2, 1
    trace, stats = lmc.sample(
        logp_dlogp_func,
        ndim,
        n_draws,
        n_tune,
        step=nuts,
        chains=n_chains,
        cores=n_cores,
    )

    assert trace.shape == (n_chains, n_draws, ndim)

    # NOTE(review): expects every per-draw stat to carry a trailing
    # model_ndim axis — confirm against lmc.sample's stats layout.
    expected_stat_shape = (n_chains, n_draws, ndim)
    shape_checks = [
        stats[key].shape == expected_stat_shape
        for key, _unused in nuts.stats_dtypes[0].items()
    ]
    assert all(shape_checks)

    dtype_checks = [
        stats[key].dtype == want
        for key, want in nuts.stats_dtypes[0].items()
    ]
    assert all(dtype_checks)
def test_multiprocessing_with_various_frameworks(framework):
    """Sample 4 chains on 4 cores using the log-density of ``framework``.

    ``framework`` must be one of ``"pytorch"``, ``"jax"`` or ``"pymc3"``.
    Any other value raises ``KeyError`` before sampling starts.
    """
    # NOTE(review): ``framework`` is presumably supplied by a pytest
    # parametrize decorator or fixture that is not visible here — confirm.
    func_for = {
        "pytorch": torch_logp_dlogp_func,
        "jax": jax_logp_dlogp_func,
        "pymc3": pm_logp_dlogp_func,
    }
    chosen = func_for[framework]

    ndim, n_tune, n_draws = 3, 10, 10
    n_chains = n_cores = 4
    trace, _stats = lmc.sample(
        logp_dlogp_func=chosen,
        model_ndim=ndim,
        tune=n_tune,
        draws=n_draws,
        chains=n_chains,
        cores=n_cores,
        progressbar=False,
    )
    assert trace.shape == (n_chains, n_draws, ndim)
def test_reset_tuning():
    """Check the tuning counters of a step built by ``lmc.init_nuts``.

    After sampling with ``tune`` tuning steps, the potential's sample count
    must equal ``tune``, and the step-size adapter's count must equal
    ``tune + 1``.
    """
    ndim = 1
    n_tune = 50
    start, nuts = lmc.init_nuts(logp_dlogp_func=logp_dlogp_func, model_ndim=ndim)
    lmc.sample(
        logp_dlogp_func,
        ndim,
        draws=2,
        tune=n_tune,
        chains=2,
        step=nuts,
        start=start,
        cores=1,
    )
    # These are private attributes; the expected values below are what this
    # test pins, not a documented public contract.
    assert nuts.potential._n_samples == n_tune
    assert nuts.step_adapt._count == n_tune + 1
def test_multiprocess_sampling_runs():
    """Smoke test: sampling with ``chains`` and ``cores`` left as None must not raise."""
    # NOTE(review): this version uses the ``size=`` keyword, apparently an
    # older API than the ``model_ndim=`` used elsewhere in this SOURCE. A
    # same-named test appears later; if both share a module, that one
    # shadows this one.
    dim = 1
    nuts = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    lmc.sample(logp_dlogp_func, dim, 1, 1, step=nuts, chains=None, cores=None)
def test_nuts_sampling_runs():
    """Single-chain NUTS run with the ``size=`` keyword; check the trace shape."""
    # NOTE(review): the expected shape (1, 2) is (chains, draws) with no
    # trailing dimension axis, unlike the model_ndim-based tests in this
    # SOURCE — this appears to target an older sample() layout.
    dim = 1
    nuts = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    n_draws, n_tune = 2, 1
    trace, _stats = lmc.sample(
        logp_dlogp_func, dim, n_draws, n_tune, step=nuts, chains=1, cores=1
    )
    assert trace.shape == (1, 2)
def test_samples_not_all_same():
    """Draws from a run without an explicit step method must vary (positive variance)."""
    ndim = 1
    n_draws, n_tune = 50, 10
    trace, _stats = lmc.sample(
        logp_dlogp_func, ndim, n_draws, n_tune, chains=1, cores=1
    )
    assert np.var(trace) > 0
def test_nuts_recovers_1d_normal():
    """NUTS draws should have mean near 0 and std near 1, within ``atol=1``.

    The target is presumably a standard normal (per the test name); confirm
    against the module-level ``logp_dlogp_func``.
    """
    # NOTE(review): uses the older ``size=`` keyword. Also, atol=1 is very
    # loose: any std in [0, 2] passes. Consider tightening once flakiness
    # has been measured.
    dim = 1
    nuts = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    trace, _stats = lmc.sample(
        logp_dlogp_func, dim, 1000, 1000, step=nuts, chains=1, cores=1
    )
    sample_mean = np.mean(trace)
    assert np.allclose(sample_mean, 0, atol=1)
    sample_std = np.std(trace)
    assert np.allclose(sample_std, 1, atol=1)
def test_multiprocess_sampling_runs():
    """Run NUTS on 4 chains across 4 cores and check the resulting trace shape.

    Without an assertion, this test would only check that multiprocess
    sampling does not raise. It now also checks that every chain's draws
    are returned in the ``(chains, draws, model_ndim)`` layout asserted by
    the other ``model_ndim``-based tests.
    """
    model_ndim = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=model_ndim)
    draws = 1
    tune = 1
    chains = 4
    cores = 4
    trace, stats = lmc.sample(
        logp_dlogp_func, model_ndim, draws, tune, step=step, chains=chains, cores=cores
    )
    # Catches chains silently lost or mis-stacked when results are gathered
    # from the worker processes.
    assert trace.shape == (chains, draws, model_ndim)
def test_nuts_tuning():
    """After sampling finishes, the NUTS step must no longer be in tuning mode."""
    # NOTE(review): a ``size=``-based test with the same name follows in this
    # SOURCE; if both share a module, the later one shadows this one.
    ndim = 1
    nuts = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=ndim)
    n_draws = n_tune = 5
    lmc.sample(logp_dlogp_func, ndim, n_draws, n_tune, step=nuts, chains=1, cores=1)
    assert not nuts.tune
def test_nuts_tuning():
    """After sampling with the ``size=`` API, the NUTS step must have left tuning."""
    dim = 1
    nuts = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    n_draws = n_tune = 5
    lmc.sample(logp_dlogp_func, dim, n_draws, n_tune, step=nuts, chains=1, cores=1)
    assert not nuts.tune
def test_hmc_recovers_1d_normal():
    """HamiltonianMC draws should have mean near 0 and std near 1, within ``atol=1``.

    The target is presumably a standard normal (per the test name); confirm
    against the module-level ``logp_dlogp_func``.
    """
    # NOTE(review): atol=1 is very loose (any std in [0, 2] passes).
    ndim = 1
    hmc = lmc.HamiltonianMC(logp_dlogp_func=logp_dlogp_func, model_ndim=ndim)
    n_draws = n_tune = 1000
    trace, _stats = lmc.sample(
        logp_dlogp_func, ndim, n_draws, n_tune, step=hmc, chains=1, cores=1
    )
    sample_mean = np.mean(trace)
    assert np.allclose(sample_mean, 0, atol=1)
    sample_std = np.std(trace)
    assert np.allclose(sample_std, 1, atol=1)
# TorchScript-compiled model; its parameters are overwritten on every call to
# torch_logp_dlogp_func below. LinearModel, x and y_obs are defined outside
# this view.
torch_model = torch.jit.script(LinearModel())
# The order of this list defines how entries of the sampler position map onto
# model parameters: position[0] -> m, position[1] -> b, position[2] -> logs.
torch_params = [torch_model.m, torch_model.b, torch_model.logs]
# Observed data, converted once to double-precision tensors.
args = [
    torch.tensor(x, dtype=torch.double), torch.tensor(y_obs, dtype=torch.double)
]


def torch_logp_dlogp_func(x):
    """Evaluate the torch model and its gradient at sampler position ``x``.

    NOTE: this parameter ``x`` shadows the module-level data array ``x`` used
    to build ``args``. Inside this function ``x`` is the sampler position, one
    entry per element of ``torch_params``.

    Returns a pair: the model output as a NumPy value (presumably the
    log-density, given how it is passed to lmc.sample) and a NumPy array of
    gradients ordered like ``torch_params``.
    """
    for i, p in enumerate(torch_params):
        # Overwrite the parameter value so the scripted model sees position x.
        p.data = torch.tensor(x[i])
        # backward() accumulates into .grad, so clear any gradient left over
        # from the previous call before computing a new one.
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()

    result = torch_model(*args)
    result.backward()
    return result.detach().numpy(), np.array(
        [p.grad.numpy() for p in torch_params])


# Sample the 3-parameter model (m, b, logs) with 4 chains.
trace, stats = lmc.sample(
    logp_dlogp_func=torch_logp_dlogp_func,
    model_ndim=3,
    tune=500,
    draws=1000,
    chains=4,
)