Пример #1
0
def test_rng_seed(backend):
    """Re-seeding with the same rng_seed must reproduce the same sample.

    A one-site model drawing ``x ~ Normal(0, 1)`` is run twice, each time
    under a fresh ``handlers.seed(rng_seed=0)``, inside the requested backend.
    The two draws are compared with the backend's ``ops.allclose``.
    """
    def model():
        return pyro.sample("x", dist.Normal(0, 1))

    def draw_with_fixed_seed():
        # A new seed handler per call, so each draw starts from the same RNG state.
        with handlers.seed(rng_seed=0):
            return model()

    with pyro_backend(backend):
        expected = draw_with_fixed_seed()
        actual = draw_with_fixed_seed()
        assert ops.allclose(actual, expected)
Пример #2
0
def test_model_sample(model, backend):
    """Smoke-test that the model registered under ``model`` runs on ``backend``.

    The test is skipped when the backend's package cannot be imported.
    ``MODELS[model]()`` is expected to return a mapping with a required
    ``'model'`` callable and optional ``'model_args'`` / ``'model_kwargs'``
    entries. Presumably this is a dict; verify against MODELS.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend):
        with handlers.seed(rng_seed=2):
            spec = MODELS[model]()
            # Keep the `model` parameter (the registry key) distinct from the callable.
            model_fn = spec['model']
            model_args = spec.get('model_args', ())
            model_kwargs = spec.get('model_kwargs', {})
            model_fn(*model_args, **model_kwargs)
Пример #3
0
def test_trace_handler(model, backend):
    """Check that ``handlers.trace(...).get_trace`` runs for ``model`` on ``backend``.

    Every backend is expected to implement the trace handler. The resulting
    trace is not inspected; the test only requires that recording it succeeds.
    The test is skipped when the backend's package cannot be imported.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend):
        with handlers.seed(rng_seed=2):
            spec = MODELS[model]()
            model_fn = spec['model']
            model_args = spec.get('model_args', ())
            model_kwargs = spec.get('model_kwargs', {})
            traced_model = handlers.trace(model_fn)
            traced_model.get_trace(*model_args, **model_kwargs)
Пример #4
0
def test_mcmc_interface(model, backend):
    """Smoke-test the backend's NUTS/MCMC interface on the registered ``model``.

    The test runs a short chain (10 warmup steps, 10 samples) under
    ``handlers.seed(rng_seed=20)`` and then calls ``summary()``. It is skipped
    when the backend's package cannot be imported.
    """
    pytest.importorskip(PACKAGE_NAME[backend])
    with pyro_backend(backend):
        with handlers.seed(rng_seed=20):
            spec = MODELS[model]()
            model_fn = spec['model']
            args = spec.get('model_args', ())
            kwargs = spec.get('model_kwargs', {})
            kernel = infer.NUTS(model=model_fn)
            mcmc = infer.MCMC(kernel, num_samples=10, warmup_steps=10)
            mcmc.run(*args, **kwargs)
            mcmc.summary()