def test_rng_seed(backend):
    """Check that ``handlers.seed`` makes sampling reproducible under ``backend``.

    Draws one value from ``dist.Normal(0, 1)`` twice, each time inside a fresh
    ``handlers.seed(rng_seed=0)`` context, and asserts the two draws agree.
    The test is skipped when the backend's package is not installed, matching
    the other backend-parametrized tests in this module.
    """
    pytest.importorskip(PACKAGE_NAME[backend])

    def model():
        return pyro.sample("x", dist.Normal(0, 1))

    with pyro_backend(backend):
        with handlers.seed(rng_seed=0):
            expected = model()
        with handlers.seed(rng_seed=0):
            actual = model()
        # Compare while the backend is still active: ``ops`` is presumably
        # resolved through ``pyro_backend`` like ``pyro``/``dist``/``handlers``,
        # so it must see the same backend that produced both samples.
        assert ops.allclose(actual, expected)
def test_model_sample(model, backend):
    """Smoke-test that the registered model ``model`` runs under ``backend``.

    Skips when the backend package is unavailable. Otherwise it builds the
    model spec from ``MODELS[model]`` and calls the model once with the spec's
    ``model_args``/``model_kwargs``, all inside a seeded backend context.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        spec = MODELS[model]()
        model_fn = spec['model']
        positional = spec.get('model_args', ())
        keyword = spec.get('model_kwargs', {})
        model_fn(*positional, **keyword)
def test_trace_handler(model, backend):
    """Check that ``handlers.trace`` can record a run of model ``model``.

    Skips when the backend package is unavailable. Otherwise it wraps the
    model from ``MODELS[model]`` in ``handlers.trace`` and calls ``get_trace``
    with the spec's arguments, inside a seeded backend context.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=2):
        spec = MODELS[model]()
        model_fn = spec['model']
        positional = spec.get('model_args', ())
        keyword = spec.get('model_kwargs', {})
        # Every backend is expected to implement trace(...).get_trace.
        handlers.trace(model_fn).get_trace(*positional, **keyword)
def test_mcmc_interface(model, backend):
    """Run a short NUTS-driven MCMC chain on model ``model`` under ``backend``.

    Skips when the backend package is unavailable. Otherwise it builds a NUTS
    kernel for the model from ``MODELS[model]``. It then runs ``infer.MCMC``
    with 10 warmup steps and 10 samples on the spec's arguments and prints the
    summary, all inside a seeded backend context.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend), handlers.seed(rng_seed=20):
        spec = MODELS[model]()
        model_fn = spec['model']
        positional = spec.get('model_args', ())
        keyword = spec.get('model_kwargs', {})
        kernel = infer.NUTS(model=model_fn)
        sampler = infer.MCMC(kernel, num_samples=10, warmup_steps=10)
        sampler.run(*positional, **keyword)
        sampler.summary()