def test_all_escape(self):
    """Check that escaping with ``all_escape`` exits non-locally at site "x".

    ``self.model`` is wrapped in ``poutine.escape`` with ``all_escape``
    bound to a fresh ``poutine.Trace()``.  Running it must raise
    ``NonlocalExit`` whose ``site`` is named ``"x"``.
    """
    escape_fn = functools.partial(all_escape, poutine.Trace())
    escaped_model = poutine.escape(self.model, escape_fn)
    try:
        escaped_model()
    except NonlocalExit as exit_signal:
        assert exit_signal.site["name"] == "x"
    else:
        # Reaching this point means the model ran to completion without
        # escaping, which is a failure for this test.
        assert False
def test_discrete_escape(self):
    """Check that escaping with ``discrete_escape`` exits at site "y".

    ``self.model`` is wrapped in ``poutine.escape`` whose ``escape_fn`` is
    ``discrete_escape`` bound to a fresh ``poutine.Trace()``.  Running it
    must raise ``NonlocalExit`` whose ``site`` is named ``"y"``.
    """
    predicate = functools.partial(discrete_escape, poutine.Trace())
    escaped_model = poutine.escape(self.model, escape_fn=predicate)
    try:
        escaped_model()
    except NonlocalExit as exit_signal:
        assert exit_signal.site["name"] == "y"
    else:
        # The model finished without a non-local exit: the test fails.
        assert False
def test_trace_compose(self):
    """Check how ``trace`` and ``escape`` compose in either nesting order.

    * ``escape`` wrapped around ``trace``: after the ``NonlocalExit`` the
      trace is expected to contain site ``"x"``.
    * ``trace`` wrapped around ``escape``: after the ``NonlocalExit`` the
      trace is expected NOT to contain site ``"x"``.
    """
    # Case 1: escape(trace(model)).
    tm = poutine.trace(self.model)
    try:
        escape_outside = poutine.escape(
            tm, functools.partial(all_escape, poutine.Trace())
        )
        escape_outside()
    except NonlocalExit:
        assert "x" in tm.trace
    else:
        # No non-local exit happened; the test must fail.
        assert False

    # Case 2: trace(escape(model)).  The handler is built inside the ``try``
    # exactly as in the original so error paths are unchanged.
    try:
        escape_inside = poutine.escape(
            self.model, functools.partial(all_escape, poutine.Trace())
        )
        tem = poutine.trace(escape_inside)
        tem()
    except NonlocalExit:
        assert "x" not in tem.trace
    else:
        assert False
def _fn(*args, **kwargs):
    """Replay ``fn`` against traces popped from ``queue`` until one completes.

    On each attempt a ``(priority, trace)`` pair is taken from the enclosing
    ``queue`` and ``fn`` is run under ``replay`` of that trace, with an
    ``escape`` handler driven by ``sample_escape`` and an outer ``trace``.
    If the run completes, its return value is returned.  If it exits with
    ``NonlocalExit``, every extension of the recorded trace at the exiting
    site is pushed back on the queue, keyed by its ``log_prob_sum()`` minus
    a small random perturbation.

    :param args: positional arguments forwarded to the traced ``fn``.
    :param kwargs: keyword arguments forwarded to the traced ``fn``.
    :returns: whatever the traced ``fn`` returns on the first complete run.
    :raises AssertionError: if the queue is empty when an attempt starts.
    :raises ValueError: if no run completes within the attempt budget.
    """
    attempt_budget = int(1e6)
    for _ in range(attempt_budget):
        assert not queue.empty(), "trying to get() from an empty queue will deadlock"

        # The stored priority is only used for ordering inside the queue.
        _priority, next_trace = queue.get()
        try:
            replayed = poutine.replay(fn, next_trace)
            should_escape = functools.partial(sample_escape, next_trace)
            ftr = poutine.trace(poutine.escape(replayed, should_escape))
            return ftr(*args, **kwargs)
        except NonlocalExit as site_container:
            site_container.reset_stack()
            extensions = poutine.util.enum_extend(
                ftr.trace.copy(), site_container.site
            )
            for extended in extensions:
                score = extended.log_prob_sum().item()
                # add a little bit of noise to the priority to break ties...
                jitter = torch.rand(1).item() * 1e-2
                queue.put((score - jitter, extended))

    raise ValueError("max tries ({}) exceeded".format(str(1e6)))
def iter_discrete_traces(graph_type, fn, *args, **kwargs):
    """
    Enumerate traces of a stochastic function over its discrete choices.

    Continuous random variables are sampled just as `fn` would sample them,
    while for discrete random variables every choice is visited in turn.
    Each item produced is a `(scale, trace)` pair in which `scale` is the
    probability of the discrete choices made in that `trace`.

    :param str graph_type: The type of the graph, e.g. "flat" or "dense".
    :param callable fn: A stochastic function.
    :param args: Positional arguments passed through to `fn`.
    :param kwargs: Keyword arguments passed through to `fn`.
    :returns: An iterator over (scale, trace) pairs.
    """
    # LIFO work list of partial traces; the most recently queued one is
    # resumed first.  It starts with a single empty trace.
    pending = LifoQueue()
    pending.put(Trace())

    while not pending.empty():
        prefix = pending.get()
        stop_at_discrete = functools.partial(util.discrete_escape, prefix)
        replayed = poutine.replay(fn, prefix)
        escaping = poutine.escape(replayed, stop_at_discrete)
        tracer = poutine.trace(escaping, graph_type=graph_type)

        try:
            complete = tracer.get_trace(*args, **kwargs)
        except util.NonlocalExit as exit_signal:
            # The run stopped at a site: queue every extension of the
            # trace recorded so far and move on to the next partial trace.
            recorded = tracer.trace.copy()
            for candidate in util.enum_extend(recorded, exit_signal.site):
                pending.put(candidate)
        else:
            # Scale trace by probability of discrete choices.
            log_pdf = complete.batch_log_pdf(site_filter=site_is_discrete)
            if isinstance(log_pdf, float):
                log_pdf = torch.Tensor([log_pdf])
            if isinstance(log_pdf, torch.Tensor):
                log_pdf = Variable(log_pdf)
            yield torch.exp(log_pdf.detach()), complete
def nested_model():
    """Sample Bernoulli(0.5) sites "internal0" through "internal3".

    "internal0" is sampled with no extra handler.  The remaining sites are
    sampled inside a ``poutine.escape`` context whose ``escape_fn`` matches
    the message named "internal2".
    """

    def _is_internal2(msg):
        # Escape predicate: true only for the site named "internal2".
        return msg["name"] == "internal2"

    pyro.sample("internal0", dist.Bernoulli(0.5))
    # NOTE(review): the original indentation was lost; "internal1" through
    # "internal3" are assumed to all sit inside the with-block - confirm.
    with poutine.escape(escape_fn=_is_internal2):
        for site_name in ("internal1", "internal2", "internal3"):
            pyro.sample(site_name, dist.Bernoulli(0.5))