def config_enumerate(fn=None, default="parallel"):
    """
    Configures enumeration for all relevant sites in a NumPyro model.

    When configuring for exhaustive enumeration of discrete variables, this
    configures all sample sites whose distribution satisfies
    ``.has_enumerate_support == True``. An eligible site whose ``infer`` dict
    already contains an ``"enumerate"`` entry keeps that value; otherwise
    ``default`` is used.

    This can be used as either a function::

        model = config_enumerate(model)

    or as a decorator, with or without arguments::

        @config_enumerate
        def model(*args, **kwargs):
            ...

        @config_enumerate(default="parallel")
        def model(*args, **kwargs):
            ...

    .. note:: Currently, only ``default='parallel'`` is supported.

    :param callable fn: Python callable with NumPyro primitives. If ``None``,
        a decorator with ``default`` already bound is returned instead.
    :param str default: Which enumerate strategy to use, one of
        "sequential", "parallel", or None. Defaults to "parallel".
    :return: the result of ``infer_config(fn, config_fn)``, or a
        :class:`functools.partial` of this function when ``fn`` is ``None``.
    """
    if fn is None:
        # Called with only keyword arguments, e.g.
        # ``@config_enumerate(default=...)``: hand back a decorator that will
        # receive the model callable on the next call.
        import functools

        return functools.partial(config_enumerate, default=default)

    def config_fn(site):
        # Only unobserved sample sites whose distribution reports
        # enumerate support are configured; everything else is left alone.
        if (
            site["type"] == "sample"
            and not site["is_observed"]
            and site["fn"].has_enumerate_support
        ):
            # An explicit per-site ``infer={"enumerate": ...}`` takes precedence.
            return {"enumerate": site["infer"].get("enumerate", default)}
        return {}

    return infer_config(fn, config_fn)
def config_enumerate(fn=None, default="parallel"):
    """
    Configures enumeration for all relevant sites in a NumPyro model.

    Every unobserved sample site whose distribution reports
    ``.has_enumerate_support == True`` is tagged for enumeration. If such a
    site already carries an ``"enumerate"`` entry in its ``infer`` dict, that
    entry is kept; otherwise ``default`` is applied.

    Usable directly::

        model = config_enumerate(model)

    or as a decorator::

        @config_enumerate
        def model(*args, **kwargs):
            ...

    .. note:: Currently, only ``default='parallel'`` is supported.

    :param callable fn: Python callable with NumPyro primitives. When omitted
        (``None``), a partially-applied ``config_enumerate`` is returned so
        the call can serve as a decorator factory.
    :param str default: Which enumerate strategy to use, one of
        "sequential", "parallel", or None. Defaults to "parallel".
    :return: ``infer_config(fn, ...)`` for a given ``fn``, otherwise a
        :class:`functools.partial` binding ``default``.
    """

    def _enumeration_for(site):
        # Guard clauses mirror the original short-circuit order, so
        # "is_observed" is only read for sample sites and "fn" only for
        # unobserved ones.
        if site["type"] != "sample":
            return {}
        if site["is_observed"]:
            return {}
        if not site["fn"].has_enumerate_support:
            return {}
        # A per-site ``infer={"enumerate": ...}`` overrides ``default``.
        chosen = site["infer"].get("enumerate", default)
        return {"enumerate": chosen}

    if fn is not None:
        return infer_config(fn, _enumeration_for)

    # No model supplied yet: defer until the decorated callable arrives.
    return functools.partial(config_enumerate, default=default)