コード例 #1
0
def assert_ok(model, *args, **kwargs):
    """
    Smoke-test MCMC on ``model`` by running a tiny NUTS chain
    (2 warmup steps, 2 samples); any exception raised propagates and
    fails the caller.

    This function does not itself escalate warnings to errors; if warnings
    should fail the check, that must come from the surrounding test
    configuration.

    ``*args`` and ``**kwargs`` are forwarded unchanged to ``MCMC.run``.
    """
    # Reset the global parameter store so earlier tests cannot leak state in.
    pyro.get_param_store().clear()
    sampler = infer.MCMC(infer.NUTS(model), num_samples=2, warmup_steps=2)
    sampler.run(*args, **kwargs)
コード例 #2
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
def assert_ok(model, guide, elbo, *args, **kwargs):
    """
    Smoke-test SVI by taking two optimization steps of ``model``/``guide``
    under the given ``elbo``; any exception raised propagates and fails the
    caller.

    Uses Adam with a very small learning rate (1e-6), presumably so the run
    exercises the inference machinery without meaningfully moving the
    parameters. This function does not itself escalate warnings to errors.

    ``*args`` and ``**kwargs`` are forwarded unchanged to every ``SVI.step``.
    """
    # Reset the global parameter store so earlier tests cannot leak state in.
    pyro.get_param_store().clear()
    svi = infer.SVI(model, guide, optim.Adam({"lr": 1e-6}), elbo)
    for _ in range(2):
        svi.step(*args, **kwargs)
コード例 #3
0
ファイル: test_svi.py プロジェクト: feynmanliang/pyro-api
def test_optimizer(backend, optim_name, optim_kwargs, jit):
    """
    Smoke-test SVI with the optimizer class named ``optim_name``.

    The model has a single learnable parameter ``p`` driving an observed
    Bernoulli site; the guide is empty. ``jit`` selects the JIT-traced or the
    plain Trace ELBO, and two SVI steps are taken.

    ``backend`` is not referenced in this body; presumably it is consumed by a
    fixture/parametrization that selects the pyro-api backend (TODO confirm).
    """

    def bernoulli_model(obs):
        prob = pyro.param("p", ops.tensor(0.5))
        pyro.sample("x", dist.Bernoulli(prob), obs=obs)

    def empty_guide(obs):
        # The model has no latent sample sites, so the guide has nothing to do.
        pass

    observed = ops.tensor(0.)
    pyro.get_param_store().clear()

    if jit:
        elbo_cls = infer.JitTrace_ELBO
    else:
        elbo_cls = infer.Trace_ELBO
    loss = elbo_cls(ignore_jit_warnings=True)

    # Pass a copy of the kwargs so the shared parametrized dict is never
    # exposed to mutation by the optimizer wrapper.
    optimizer_cls = getattr(optim, optim_name)
    optimizer = optimizer_cls(optim_kwargs.copy())

    svi = infer.SVI(bernoulli_model, empty_guide, optimizer, loss)
    for _ in range(2):
        svi.step(observed)