def test_nuts_sampling_runs():
    """Run NUTS on a 1-D model over two chains and check output shapes and dtypes.

    Verifies that ``lmc.sample`` returns a trace of shape
    ``(chains, draws, model_ndim)`` and that every statistic declared in
    ``step.stats_dtypes[0]`` is returned with the asserted shape and with its
    declared dtype.
    """
    # NOTE(review): another ``def test_nuts_sampling_runs`` appears later in
    # this file and rebinds this name, so pytest only collects the later one.
    model_ndim = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=model_ndim)
    draws = 3
    tune = 1
    chains = 2
    cores = 1
    trace, stats = lmc.sample(
        logp_dlogp_func, model_ndim, draws, tune, step=step, chains=chains, cores=cores
    )
    assert trace.shape == (chains, draws, model_ndim)
    # Check each declared stat on its own so a failure reports which stat is
    # wrong, instead of a bare ``all([...])`` evaluating to False.
    for name, expected_dtype in step.stats_dtypes[0].items():
        # NOTE(review): per-draw sampler stats may be scalars per draw, in which
        # case (chains, draws) would be the natural shape; kept as-is — confirm.
        assert stats[name].shape == (chains, draws, model_ndim), (
            f"{name}: shape {stats[name].shape}"
        )
        assert stats[name].dtype == expected_dtype, (
            f"{name}: dtype {stats[name].dtype}, expected {expected_dtype}"
        )
def test_multiprocess_sampling_runs():
    """Smoke-test ``lmc.sample`` with both the chain and core counts left as None.

    The choice of chain and core counts is deferred to ``lmc.sample``
    (presumably its defaults). The test passes as long as sampling completes
    and returns a two-element result.
    """
    # NOTE(review): another ``def test_multiprocess_sampling_runs`` appears
    # later in this file and rebinds this name, so pytest only collects that one.
    dim = 1
    n_draws = 1
    n_tune = 1
    n_chains = None
    n_cores = None
    sampler = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    _trace, _stats = lmc.sample(
        logp_dlogp_func,
        dim,
        n_draws,
        n_tune,
        step=sampler,
        chains=n_chains,
        cores=n_cores,
    )
def test_nuts_sampling_runs():
    """Run one short NUTS chain (two draws) and check the shape of the trace."""
    dim = 1
    sampler = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    n_draws, n_tune = 2, 1
    n_chains, n_cores = 1, 1
    trace, _stats = lmc.sample(
        logp_dlogp_func,
        dim,
        n_draws,
        n_tune,
        step=sampler,
        chains=n_chains,
        cores=n_cores,
    )
    # (1, 2) agrees with chains=1 / draws=2, but size is also 1 here, so this
    # literal cannot pin which axis is which. NOTE(review): confirm the layout.
    assert trace.shape == (1, 2)
def test_nuts_recovers_1d_normal():
    """Check that NUTS recovers the mean and standard deviation of a 1-D normal.

    Assumes ``logp_dlogp_func`` is the log-density (with gradient) of a
    standard normal. The expected values 0 and 1 below depend on that.
    """
    size = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=size)
    draws = 1000
    tune = 1000
    chains = 1
    cores = 1
    trace, stats = lmc.sample(
        logp_dlogp_func, size, draws, tune, step=step, chains=chains, cores=cores
    )
    sample_mean = np.mean(trace)
    sample_std = np.std(trace)
    # A tolerance of 1 would accept any std in [0, 2]. That includes a chain
    # that never moves (std == 0), so the test could not detect a broken
    # sampler. Assuming reasonable mixing, the Monte Carlo error of both
    # estimates from 1000 draws is a few hundredths, so 0.2 still leaves a
    # wide margin against flakiness.
    assert np.isclose(sample_mean, 0, atol=0.2), f"mean={sample_mean}"
    assert np.isclose(sample_std, 1, atol=0.2), f"std={sample_std}"
def test_multiprocess_sampling_runs():
    """Sample four chains with ``cores=4`` and check the shape of the trace.

    Presumably this exercises the multi-process path of ``lmc.sample``
    (``cores > 1``). The combined trace is expected to have the same
    ``(chains, draws, model_ndim)`` layout as single-process sampling.
    """
    model_ndim = 1
    step = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=model_ndim)
    draws = 1
    tune = 1
    chains = 4
    cores = 4
    trace, stats = lmc.sample(
        logp_dlogp_func, model_ndim, draws, tune, step=step, chains=chains, cores=cores
    )
    # Without this the test only checks that nothing raised. Merging
    # per-process results must still yield one row per chain.
    assert trace.shape == (chains, draws, model_ndim)
def test_nuts_tuning():
    """Check that the NUTS step is no longer in tuning mode after sampling.

    Runs 5 tuning iterations followed by 5 draws, then asserts that
    ``step.tune`` is falsy.
    """
    # NOTE(review): another ``def test_nuts_tuning`` appears later in this file
    # and rebinds this name, so pytest only collects the later one.
    ndim = 1
    n_tune = 5
    n_draws = 5
    n_chains = n_cores = 1
    sampler = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, model_ndim=ndim)
    _trace, _stats = lmc.sample(
        logp_dlogp_func,
        ndim,
        n_draws,
        n_tune,
        step=sampler,
        chains=n_chains,
        cores=n_cores,
    )
    assert not sampler.tune
def test_nuts_tuning():
    """Check that the NUTS step is no longer in tuning mode after sampling.

    Runs 5 tuning iterations followed by 5 draws on a one-dimensional model,
    then asserts that ``step.tune`` is falsy.
    """
    dim = 1
    n_tune = 5
    n_draws = 5
    n_chains = n_cores = 1
    sampler = lmc.NUTS(logp_dlogp_func=logp_dlogp_func, size=dim)
    _trace, _stats = lmc.sample(
        logp_dlogp_func,
        dim,
        n_draws,
        n_tune,
        step=sampler,
        chains=n_chains,
        cores=n_cores,
    )
    assert not sampler.tune