Example #1
0
def test_lift():
    """Verify that ``handlers.lift`` swaps selected params for prior samples.

    The model is traced twice: once as written and once wrapped in
    ``handlers.lift`` with priors supplied for ``loc1`` and ``scale1`` only.
    Every site recorded in the plain trace must also be present in the lifted
    trace. The two sites that have a prior must be ``sample`` sites whose
    ``fn`` is the very prior object that was passed in, while ``loc2`` and
    ``scale2`` must still be recorded as ``param`` sites.
    """
    def model():
        # First parameter pair: priors are supplied for these below.
        mu1 = numpyro.param("loc1", 0.)
        sigma1 = numpyro.param("scale1", 1., constraint=constraints.positive)
        numpyro.sample("latent1", dist.Normal(mu1, sigma1))

        # Second parameter pair: no prior is supplied for these.
        mu2 = numpyro.param("loc2", 1.)
        sigma2 = numpyro.param("scale2", 2., constraint=constraints.positive)
        return numpyro.sample("latent2", dist.Normal(mu2, sigma2))

    prior = {"loc1": dist.Normal(), "scale1": dist.LogNormal()}

    # Reference run: the model without any lifting.
    with handlers.trace() as tr, handlers.seed(rng_seed=1):
        model()

    # Same model, with the priors applied through ``lift``.
    with handlers.trace() as lifted_tr, handlers.seed(rng_seed=2):
        with handlers.lift(prior=prior):
            model()

    untouched_params = ('loc2', 'scale2')
    for site_name in tr.keys():
        assert site_name in lifted_tr
        if site_name in prior:
            lifted_site = lifted_tr[site_name]
            # Identity check: the trace must hold the exact prior instance.
            assert lifted_site['fn'] is prior[site_name]
            assert lifted_site['type'] == 'sample'
            # 0. and 1. are the init values given to loc1 / scale1 above.
            assert lifted_site['value'] not in (0., 1.)
        elif site_name in untouched_params:
            assert lifted_tr[site_name]['type'] == 'param'
Example #2
0
def test_lift():
    """Check which param sites ``handlers.lift`` converts into sample sites.

    A small two-latent model is executed under ``handlers.trace`` both with
    and without ``handlers.lift``. Priors are provided for ``loc1`` and
    ``scale1`` only, so in the lifted trace those two sites are expected to
    be ``sample`` sites carrying the supplied prior objects, whereas
    ``loc2`` and ``scale2`` are expected to stay ``param`` sites. All site
    names from the plain trace must exist in the lifted trace.
    """

    def model():
        # Params that receive a prior in the lifted run.
        first_loc = numpyro.param("loc1", 0.0)
        first_scale = numpyro.param(
            "scale1", 1.0, constraint=constraints.positive
        )
        numpyro.sample("latent1", dist.Normal(first_loc, first_scale))

        # Params that are left without a prior.
        second_loc = numpyro.param("loc2", 1.0)
        second_scale = numpyro.param(
            "scale2", 2.0, constraint=constraints.positive
        )
        return numpyro.sample("latent2", dist.Normal(second_loc, second_scale))

    prior = {"loc1": dist.Normal(), "scale1": dist.LogNormal()}

    # Baseline trace of the unmodified model.
    with handlers.trace() as tr, handlers.seed(rng_seed=1):
        model()

    # Trace of the model with ``lift`` applied innermost.
    with handlers.trace() as lifted_tr, handlers.seed(rng_seed=2):
        with handlers.lift(prior=prior):
            model()

    params_without_prior = ("loc2", "scale2")
    for site_name in tr.keys():
        assert site_name in lifted_tr
        if site_name in prior:
            lifted_site = lifted_tr[site_name]
            # The recorded distribution must be the same object we passed in.
            assert lifted_site["fn"] is prior[site_name]
            assert lifted_site["type"] == "sample"
            # 0.0 and 1.0 are the init values given to loc1 / scale1 above.
            assert lifted_site["value"] not in (0.0, 1.0)
        elif site_name in params_without_prior:
            assert lifted_tr[site_name]["type"] == "param"
Example #3
0
def test_lift_memoize():
    """Check that reading one param name twice under ``lift`` agrees.

    The model requests the ``"loc"`` param twice while wrapped in
    ``handlers.lift`` with a single ``Normal(0, 1)`` prior, and asserts that
    both requests yield equal values.
    """
    def model():
        first = numpyro.param("loc")
        second = numpyro.param("loc")
        # Both lookups use the same site name, so they must match.
        assert first == second

    with handlers.seed(rng_seed=1), handlers.lift(prior=dist.Normal(0, 1)):
        model()