def model():
    """Model whose "x"/"z" sites come from one of two branches picked by ``cond``.

    A site "cluster" is drawn from a default-parameterised ``dist.Normal()``.
    Its sign decides which branch runs:

    * ``cluster > 0``: "x" ~ ``dist.Normal(4.0)`` and "z" records ``x - 4.0``.
    * otherwise: "x" ~ ``dist.Normal(0.0)`` and "z" records ``x`` unchanged.

    Both branches use the same site names ("x", "z"). Only the location of "x"
    and the offset applied to "z" differ.
    NOTE(review): ``cond`` is presumably ``numpyro.contrib.control_flow.cond``
    (imported elsewhere in the file) — confirm.
    """

    def _positive_branch(_operand):
        # "x" centred at 4.0; "z" re-centres it by subtracting the same offset.
        shifted = numpyro.sample("x", dist.Normal(4.0))
        numpyro.deterministic("z", shifted - 4.0)

    def _nonpositive_branch(_operand):
        # "x" centred at 0.0; "z" is recorded as-is.
        centered = numpyro.sample("x", dist.Normal(0.0))
        numpyro.deterministic("z", centered)

    selector = numpyro.sample("cluster", dist.Normal())
    take_positive = selector > 0
    # The operand is unused by either branch, hence ``None``.
    cond(take_positive, _positive_branch, _nonpositive_branch, None)
def guide():
    """Guide that fits a separate Normal to site "x" for each ``cond`` branch.

    Four learnable params are registered, in this order:

    * "m1" and "s1" (init 2.0 and 0.1; "s1" constrained positive) are used
      when ``cluster > 0``.
    * "m2" and "s2" (init 2.0 and 0.1; "s2" constrained positive) are used
      otherwise.

    Site "cluster" is drawn from a default-parameterised ``dist.Normal()`` and
    its sign selects which param pair parameterises ``dist.Normal`` for "x".
    NOTE(review): presumably paired with ``model`` above, which uses the same
    "cluster"/"x" site names — confirm.
    """
    # Keep param registration order identical: m1, s1, m2, s2.
    loc_pos = numpyro.param("m1", 2.0)
    scale_pos = numpyro.param("s1", 0.1, constraint=dist.constraints.positive)
    loc_neg = numpyro.param("m2", 2.0)
    scale_neg = numpyro.param("s2", 0.1, constraint=dist.constraints.positive)

    def _make_branch(loc, scale):
        # Build a branch that samples "x" from Normal(loc, scale); the operand is ignored.
        def _branch(_operand):
            numpyro.sample("x", dist.Normal(loc, scale))

        return _branch

    selector = numpyro.sample("cluster", dist.Normal())
    cond(
        selector > 0,
        _make_branch(loc_pos, scale_pos),
        _make_branch(loc_neg, scale_neg),
        None,
    )