# NOTE: an identical, older-formatted copy of ``test_lift`` used to be defined
# here. It was dead code: the later ``def test_lift`` in this module rebinds
# the same name, so pytest only ever collected that later definition
# (pyflakes/ruff F811, "redefinition of unused name").
def test_lift():
    """Check that ``handlers.lift`` turns prior-covered params into sample sites.

    Params whose names appear in ``prior`` must become ``sample`` sites whose
    ``fn`` is exactly the supplied prior distribution. Params without a prior
    must remain ``param`` sites. Sites that were already ``sample`` sites must
    keep that type.
    """

    def model():
        loc1 = numpyro.param("loc1", 0.0)
        scale1 = numpyro.param("scale1", 1.0, constraint=constraints.positive)
        numpyro.sample("latent1", dist.Normal(loc1, scale1))
        loc2 = numpyro.param("loc2", 1.0)
        scale2 = numpyro.param("scale2", 2.0, constraint=constraints.positive)
        latent2 = numpyro.sample("latent2", dist.Normal(loc2, scale2))
        return latent2

    loc1_prior = dist.Normal()
    scale1_prior = dist.LogNormal()
    prior = {"loc1": loc1_prior, "scale1": scale1_prior}

    # Trace of the un-lifted model; used only to enumerate every site name.
    with handlers.trace() as tr:
        with handlers.seed(rng_seed=1):
            model()

    with handlers.trace() as lifted_tr:
        with handlers.seed(rng_seed=2):
            with handlers.lift(prior=prior):
                model()

    for name in tr:
        assert name in lifted_tr
        site = lifted_tr[name]
        if name in prior:
            assert site["fn"] is prior[name]
            assert site["type"] == "sample"
            # The value should be a draw from the prior, not the param's
            # initial value (0.0 for loc1, 1.0 for scale1).
            assert site["value"] not in (0.0, 1.0)
        elif name in ("loc2", "scale2"):
            # No prior was supplied for these, so they must stay params.
            assert site["type"] == "param"
        else:
            # latent1 / latent2 were sample sites already; lifting must not
            # change their type.
            assert site["type"] == "sample"
def test_lift_memoize():
    """Check that a lifted param name yields the same value on repeated calls.

    Inside the model, ``numpyro.param("loc")`` is called twice under a single
    ``handlers.lift`` context with a ``Normal(0, 1)`` prior, and both calls
    must return equal values.
    """

    def model():
        first_draw = numpyro.param("loc")
        second_draw = numpyro.param("loc")
        assert first_draw == second_draw

    # seed is entered first (outer), lift second (inner), as in nested form.
    with handlers.seed(rng_seed=1), handlers.lift(prior=dist.Normal(0, 1)):
        model()