def sample(name, fn, *args, **kwargs):
    """
    Calls the stochastic function ``fn``, with extra side-effects that depend
    on ``name`` and on the enclosing context (e.g. an inference algorithm).
    See `Introduction to Pyro <http://pyro.ai/examples/intro_long.html>`_ for
    a discussion.

    :param name: name of sample
    :param fn: distribution class or function
    :param obs: observed datum (optional; should only be used in context of
        inference) optionally specified in kwargs
    :param ~torch.Tensor obs_mask: Optional boolean tensor mask of shape
        broadcastable with ``fn.batch_shape``. If provided, events with
        mask=True will be conditioned on ``obs`` and remaining events will be
        imputed by sampling. This introduces a latent sample site named
        ``name + "_unobserved"`` which should be used by guides.
    :type obs_mask: bool or ~torch.Tensor
    :param dict infer: Optional dictionary of inference parameters specified
        in kwargs. See inference documentation for details.
    :returns: sample
    """
    observed = kwargs.pop("obs", None)
    mask_spec = kwargs.pop("obs_mask", None)
    if mask_spec is not None:
        # Partially-observed site: delegate to the helper that turns it into
        # multiple sample statements. Any "infer" entry is still inside
        # kwargs at this point and is forwarded unchanged.
        return _masked_observe(name, fn, observed, mask_spec, *args, **kwargs)

    # Work on a copy so the pop below never mutates the caller's dict.
    infer_opts = kwargs.pop("infer", {}).copy()
    # infer["is_observed"], when present, overrides the default of
    # "an obs value was supplied"; the key is removed from the copy.
    observed_flag = infer_opts.pop("is_observed", observed is not None)

    if am_i_wrapped():
        # Stack not empty: describe this site as a message, run it through
        # apply_stack, and return the value the message ends up holding.
        site_msg = {
            "type": "sample",
            "name": name,
            "fn": fn,
            "is_observed": observed_flag,
            "args": args,
            "kwargs": kwargs,
            "value": observed,
            "infer": infer_opts,
            "scale": 1.0,
            "mask": None,
            "cond_indep_stack": (),
            "done": False,
            "stop": False,
            "continuation": None,
        }
        apply_stack(site_msg)
        return site_msg["value"]

    # Stack empty: default behaviour. An observed value is echoed back with a
    # RuntimeWarning, unless infer["_deterministic"] is set, in which case fn
    # is evaluated exactly as for an unobserved call.
    if observed is not None and not infer_opts.get("_deterministic"):
        warnings.warn(
            "trying to observe a value outside of inference at " + name,
            RuntimeWarning,
        )
        return observed
    return fn(*args, **kwargs)
def param(name, *args, **kwargs):
    """
    Save the variable as a parameter in the param store.
    To interact with the param store or write to disk,
    see `Parameters <parameters.html>`_.

    :param name: name of parameter
    :returns: parameter
    """
    # Empty stack: delegate directly to the global param store.
    if not am_i_wrapped():
        return _PYRO_PARAM_STORE.get_param(name, *args, **kwargs)

    # Non-empty stack: package the request as a "param" message, pass it
    # through apply_stack, and return whatever "value" it holds afterwards.
    param_msg = dict(
        type="param",
        name=name,
        args=args,
        kwargs=kwargs,
        infer={},
        scale=1.0,
        cond_indep_stack=(),
        value=None,
        done=False,
        stop=False,
        continuation=None,
    )
    apply_stack(param_msg)
    return param_msg["value"]
def sample(name, fn, *args, **kwargs):
    """
    Calls the stochastic function `fn`, with additional side-effects that
    depend on `name` and on the enclosing context (e.g. an inference
    algorithm). See `Intro I <http://pyro.ai/examples/intro_part_i.html>`_
    and `Intro II <http://pyro.ai/examples/intro_part_ii.html>`_ for a
    discussion.

    :param name: name of sample
    :param fn: distribution class or function
    :param obs: observed datum (optional; should only be used in context of
        inference) optionally specified in kwargs
    :param dict infer: Optional dictionary of inference parameters specified
        in kwargs. See inference documentation for details.
    :returns: sample
    """
    observation = kwargs.pop("obs", None)
    # Private copy: the message built below must not alias the caller's dict.
    infer_opts = kwargs.pop("infer", {}).copy()
    has_obs = observation is not None

    if am_i_wrapped():
        # Stack not empty: an observed site starts with its value filled in
        # and is flagged as observed; otherwise "value" starts as None.
        site_msg = {
            "type": "sample",
            "name": name,
            "fn": fn,
            "is_observed": has_obs,
            "args": args,
            "kwargs": kwargs,
            "value": observation,
            "infer": infer_opts,
            "scale": 1.0,
            "mask": None,
            "cond_indep_stack": (),
            "done": False,
            "stop": False,
            "continuation": None,
        }
        apply_stack(site_msg)
        return site_msg["value"]

    # Stack empty: an observation outside inference triggers a warning and is
    # handed straight back; otherwise fn is simply evaluated.
    if has_obs:
        warnings.warn(
            "trying to observe a value outside of inference at " + name,
            RuntimeWarning)
        return observation
    return fn(*args, **kwargs)
def sample(name, fn, *args, **kwargs):
    """
    Calls the stochastic function `fn` with additional side-effects depending
    on `name` and the enclosing context (e.g. an inference algorithm).
    See `Intro I <http://pyro.ai/examples/intro_part_i.html>`_ and
    `Intro II <http://pyro.ai/examples/intro_part_ii.html>`_ for a discussion.

    :param name: name of sample
    :param fn: distribution class or function
    :param obs: observed datum (optional; should only be used in context of
        inference) optionally specified in kwargs
    :param dict infer: Optional dictionary of inference parameters specified
        in kwargs. See inference documentation for details. The dictionary
        is copied, so the caller's object is never stored in the message.
    :returns: sample
    """
    obs = kwargs.pop("obs", None)
    # Copy the caller's dict. It is stored as msg["infer"] and passed through
    # apply_stack; without a copy, any change made to msg["infer"] there would
    # leak back into the caller's (possibly shared, reused) dictionary.
    infer = kwargs.pop("infer", {}).copy()

    # Empty stack: default behavior (defined here).
    if not am_i_wrapped():
        if obs is not None:
            warnings.warn("trying to observe a value outside of inference at " + name,
                          RuntimeWarning)
            return obs
        return fn(*args, **kwargs)

    # Non-empty stack: initialize the message passed up/down the stack.
    msg = {
        "type": "sample",
        "name": name,
        "fn": fn,
        "is_observed": False,
        "args": args,
        "kwargs": kwargs,
        "value": None,
        "infer": infer,
        "scale": 1.0,
        "cond_indep_stack": (),
        "done": False,
        "stop": False,
        "continuation": None,
    }
    # An observed site carries its value in and is flagged as observed.
    if obs is not None:
        msg["value"] = obs
        msg["is_observed"] = True
    # Apply the stack and return the value the message ends up holding.
    apply_stack(msg)
    return msg["value"]