def test_nested_plate():
    """Sample one site under two nested enum plates and check its shape.

    The site ``"x"`` is drawn inside plate ``"a"`` (size 5) and plate ``"b"``
    (size 2), with enumeration configured via ``first_available_dim=-3``; the
    sampled value is expected to have shape ``(2, 5)``.
    """
    expected_shape = (2, 5)
    # A multi-item ``with`` enters the managers left to right and exits them
    # in reverse, exactly like the equivalent nested ``with`` statements.
    with enum(first_available_dim=-3), enum_plate("a", 5), enum_plate("b", 2):
        key = random.PRNGKey(0)
        x = numpyro.sample("x", dist.Normal(0, 1), rng_key=key)
        assert x.shape == expected_shape
def plate_to_enum_plate():
    """
    A context manager to replace `numpyro.plate` statement by a funsor-based
    :class:`~numpyro.contrib.funsor.enum_messenger.plate`.

    This is useful when doing inference for the usual NumPyro programs with
    `numpyro.plate` statements. For example, to get trace of a `model` whose
    discrete latent sites are enumerated, we can use

        enum_model = numpyro.contrib.funsor.enum(model)
        with plate_to_enum_plate():
            model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
                *model_args, **model_kwargs)

    While the block is active, ``numpyro.plate.__new__`` is overridden so that
    calling ``numpyro.plate(*args, **kwargs)`` returns
    ``enum_plate(*args, **kwargs)`` instead. On exit (including exit through an
    exception) ``__new__`` is set to a plain allocator that builds an instance
    of the class being instantiated; note that an explicit ``__new__``
    attribute therefore remains on ``numpyro.plate`` afterwards.
    """
    # NOTE(review): this body is generator-style (single ``yield``). Using it
    # in a ``with`` statement as documented requires it to be wrapped by
    # ``contextlib.contextmanager``; no decorator is visible on this
    # definition -- confirm against the full file.
    try:
        # The class argument is deliberately ignored: every construction is
        # redirected to the funsor-based plate.
        numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(
            *args, **kwargs
        )
        yield
    finally:
        # Allocate ``cls`` rather than hard-coding ``numpyro.plate`` so that
        # subclasses of ``numpyro.plate`` are still created as themselves (and
        # get their ``__init__`` called) once the patch has been lifted. For
        # ``numpyro.plate`` itself this is identical to the previous behavior.
        numpyro.plate.__new__ = lambda cls, *args, **kwargs: object.__new__(cls)