def test_scope():
    """Check that nested ``handlers.scope`` prefixes are prepended to site names.

    A site sampled under scope ``'a'`` is recorded as ``'a/x'``. The same site
    under scope ``'b'`` wrapping scope ``'a'`` is recorded as ``'b/a/x'``.
    """

    def fn():
        return numpyro.sample('x', dist.Normal())

    # Trace outermost, then seed, matching the original nesting order.
    with handlers.trace() as trace, handlers.seed(rng_seed=1):
        with handlers.scope(prefix='a'):
            fn()
        with handlers.scope(prefix='b'), handlers.scope(prefix='a'):
            fn()

    for site_name in ('a/x', 'b/a/x'):
        assert site_name in trace
def test_scope():
    """Check that site names pick up every enclosing scope prefix, outermost first."""

    def sample_x():
        return numpyro.sample("x", dist.Normal())

    # Functional form of the handlers: calling the wrapped function enters the
    # scope(s) for the duration of the call, equivalent to nested `with` blocks.
    scoped_a = handlers.scope(sample_x, prefix="a")
    scoped_b_a = handlers.scope(handlers.scope(sample_x, prefix="a"), prefix="b")

    with handlers.trace() as trace:
        with handlers.seed(rng_seed=1):
            scoped_a()
            scoped_b_a()
    assert "a/x" in trace
    assert "b/a/x" in trace
def birthdays_model(
    x,
    day_of_week,
    day_of_year,
    memorial_days_indicator,
    labour_days_indicator,
    thanksgiving_days_indicator,
    w0,
    L,
    M1,
    M2,
    M3,
    y=None,
):
    """Additive model: intercept + GP trends + weekday/yearday + special-day effects.

    Each component is built by a helper (``trend_gp``, ``year_gp``,
    ``weekday_effect``, ``yearday_effect``, ``special_effect``) wrapped in
    ``scope`` so that its sample sites get a distinct name prefix.

    Args:
        x: inputs for the GP components; ``x.shape[0]`` is used as the
            number of observations.
        day_of_week, day_of_year: indices passed to the weekday / yearday effects.
        memorial_days_indicator, labour_days_indicator,
        thanksgiving_days_indicator: passed one-to-one to ``special_effect``.
        w0, L, M1, M2, M3: hyper-parameters forwarded to the GP helpers
            (``L``/``M1``/``M3`` to ``trend_gp``, ``w0``/``M2`` to ``year_gp``).
        y: optional observations; when given, they condition the ``"y"`` site.
    """
    intercept = sample("intercept", dist.Normal(0, 1))

    # --- smooth components
    long_trend = scope(trend_gp, "trend")(x, L, M1)
    yearly = scope(year_gp, "year")(x, w0, M2)
    # length ~ lognormal(-1, 1) in original
    weekly_log_scale = scope(trend_gp, "week-trend")(x, L, M3)
    weekday = scope(weekday_effect, "week")(day_of_week)

    # --- day-of-year effect, then each special day added in a fixed order
    day = scope(yearday_effect, "day")(day_of_year)
    special_days = (
        ("memorial", memorial_days_indicator),
        ("labour", labour_days_indicator),
        ("thanksgiving", thanksgiving_days_indicator),
    )
    for scope_name, indicator in special_days:
        day = day + scope(special_effect, scope_name)(indicator)

    # --- combine components
    mu = deterministic(
        "f", intercept + long_trend + yearly + jnp.exp(weekly_log_scale) * weekday + day
    )
    sigma = sample("sigma", dist.HalfNormal(0.5))
    with plate("obs", x.shape[0]):
        sample("y", dist.Normal(mu, sigma), obs=y)
def body_fn(wrapped_carry, x, prefix=None):
    """Run one step of the scanned transition function ``f``.

    This is a closure: ``f``, ``substitute_stack``, ``length`` and the
    ``carry_shape_at_t1`` binding come from the enclosing function, which is
    not shown here.

    Args:
        wrapped_carry: tuple ``(i, rng_key, carry)`` holding the step index, a
            PRNG key (``None`` disables seeding) and the user carry.
        x: per-step input forwarded to ``f``.
        prefix: accepted but not used in this body.

    Returns:
        ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``. ``trace``
        is ``{}`` for the initial step; otherwise it holds the sites recorded
        by ``packed_trace`` while running ``f``.
    """
    i, rng_key, carry = wrapped_carry
    # The initial step is recognised only when `i` is a concrete (non-traced) 0.
    init = True if (not_jax_tracer(i) and i == 0) else False
    # Split the key: `subkey` seeds this step, `rng_key` is carried forward.
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

    seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
    # Re-apply the recorded condition/substitute handlers, each wrapped so it
    # receives step `i` (via `_subs_wrapper`; presumably it selects the
    # values for time step `i` -- confirm).
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == 'condition':
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == 'substitute':
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # Initial step: site names get the "_init" scope prefix and no trace
        # is collected for this step.
        with handlers.scope(prefix="_init"):
            new_carry, y = seeded_fn(carry, x)
            trace = {}
    else:
        with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # Record the shapes of the new carry's leaves in the enclosing
        # function's `carry_shape_at_t1` (a nonlocal binding, not a module global).
        nonlocal carry_shape_at_t1
        carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
        # make new_carry have the same shape as carry: each leaf is reshaped
        # to the shape of the matching leaf of the incoming carry (presumably
        # because the scan carry must keep a fixed shape -- confirm).
        # FIXME: is this rigorous?
        new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
    return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def body_fn(wrapped_carry, x, prefix=None):
    """Run one step of the scanned transition function ``f``.

    This is a closure over ``f``, ``substitute_stack``, ``length``,
    ``unroll_steps``, ``history`` and ``carry_shapes`` from the enclosing
    function, which is not shown here.

    Steps where ``i`` is a concrete (non-traced) value in
    ``range(unroll_steps)`` are "init" steps. Their site names get
    ``"_PREV_"`` repeated ``unroll_steps - i`` times as a prefix, and they
    contribute an empty trace. All other steps run under ``_promote_fn_shapes``
    and have their sites collected by ``packed_trace``.

    Args:
        wrapped_carry: tuple ``(i, rng_key, carry)`` holding the step index, a
            PRNG key (``None`` disables seeding) and the user carry.
        x: per-step input forwarded to ``f``.
        prefix: accepted but not used in this body.

    Returns:
        ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    i, rng_key, carry = wrapped_carry
    init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
    # Split the key: `subkey` seeds this step, `rng_key` is carried forward.
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (
        None, None)

    # we need to tell unconstrained messenger in potential energy computation
    # that only the item at time `i` is needed when transforming
    fn = handlers.infer_config(
        f, config_fn=lambda msg: {"_scan_current_index": i})

    seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
    # Re-apply the recorded condition/substitute handlers, each wrapped so it
    # receives step `i` (via `_subs_wrapper`; presumably it selects the
    # values for time step `i` -- confirm).
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == "condition":
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == "substitute":
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # Rename the sites to match the naming pattern expected by the
        # "sakkar_bilmes" product (presumably the Sarkka-Bilmes product --
        # confirm). With divider="" the prefixes are concatenated directly,
        # e.g. when unroll_steps - i == 2 site "x" is presumably renamed
        # "_PREV__PREV_x".
        with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
            new_carry, y = config_enumerate(seeded_fn)(carry, x)
            trace = {}
    else:
        # Like scan_wrapper, we collect the trace of scan's transition function
        # `seeded_fn` here. To put time dimension to the correct position, we need to
        # promote shapes to make `fn` and `value`
        # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
        # and value's batch_shape is (3,), then we promote shape of
        # value so that its batch shape is (1, 3)).
        # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
        # We don't promote `value` shape here because we need to store carry shape
        # at this step. If we reshape the `value` here, output carry might get wrong shape.
        with _promote_fn_shapes(), packed_trace() as trace:
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # Append the shapes of the new carry's leaves to the enclosing scope's
        # `carry_shapes` list (not a module global) until it holds
        # `history + 1` entries.
        if len(carry_shapes) < (history + 1):
            carry_shapes.append(
                [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
        # make new_carry have the same shape as carry: each leaf is reshaped
        # to the shape of the matching leaf of the incoming carry (presumably
        # because the scan carry must keep a fixed shape -- confirm).
        # FIXME: is this rigorous?
        new_carry = tree_multimap(
            lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
    return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
def test_scope_frames():
    """Check that the plate frame on a site resolves to an entry of the same trace.

    This must hold both for the plain model and for the model wrapped in
    ``handlers.scope``, where the frame name carries the scope prefix.
    """

    def model(y):
        loc = numpyro.sample("mu", dist.Normal())
        scale = numpyro.sample("sigma", dist.HalfNormal())
        with numpyro.plate("plate1", y.shape[0]):
            numpyro.sample("y", dist.Normal(loc, scale), obs=y)

    prefix = "scope"
    wrapped = handlers.scope(model, prefix=prefix)
    data = np.random.normal(size=(10,))

    def traced(fn):
        # Same seed for both runs so the two traces are directly comparable.
        return handlers.trace(handlers.seed(fn, 0)).get_trace(data)

    plain = traced(model)
    scoped = traced(wrapped)

    for tr, site in ((plain, "y"), (scoped, f"{prefix}/y")):
        frame = tr[site]["cond_indep_stack"][0]
        assert frame.name in tr