Exemplo n.º 1
0
 def test_all_escape(self):
     """Escaping with ``all_escape`` must raise NonlocalExit at site "x"."""
     escaped_model = poutine.escape(self.model,
                                    functools.partial(all_escape, poutine.Trace()))
     try:
         escaped_model()
     except NonlocalExit as e:
         assert e.site["name"] == "x"
     else:
         # Explicit failure with a message instead of an ``assert False``
         # sentinel inside the try; the try now guards only the call that
         # is expected to escape.
         raise AssertionError("expected NonlocalExit, but the model ran to completion")
Exemplo n.º 2
0
 def test_all_escape(self):
     """Escaping with ``all_escape`` must raise NonlocalExit at site "x"."""
     escaped_model = poutine.escape(self.model,
                                    functools.partial(all_escape, poutine.Trace()))
     try:
         escaped_model()
     except NonlocalExit as e:
         assert e.site["name"] == "x"
     else:
         # Explicit failure with a message instead of an ``assert False``
         # sentinel inside the try; the try now guards only the call that
         # is expected to escape.
         raise AssertionError("expected NonlocalExit, but the model ran to completion")
Exemplo n.º 3
0
 def test_discrete_escape(self):
     """Escaping with ``discrete_escape`` must raise NonlocalExit at site "y"."""
     escaped_model = poutine.escape(self.model,
                                    escape_fn=functools.partial(discrete_escape,
                                                                poutine.Trace()))
     try:
         escaped_model()
     except NonlocalExit as e:
         assert e.site["name"] == "y"
     else:
         # Explicit failure with a message instead of an ``assert False``
         # sentinel inside the try; the try now guards only the call that
         # is expected to escape.
         raise AssertionError("expected NonlocalExit, but the model ran to completion")
Exemplo n.º 4
0
 def test_discrete_escape(self):
     """Escaping with ``discrete_escape`` must raise NonlocalExit at site "y"."""
     escaped_model = poutine.escape(self.model,
                                    escape_fn=functools.partial(discrete_escape,
                                                                poutine.Trace()))
     try:
         escaped_model()
     except NonlocalExit as e:
         assert e.site["name"] == "y"
     else:
         # Explicit failure with a message instead of an ``assert False``
         # sentinel inside the try; the try now guards only the call that
         # is expected to escape.
         raise AssertionError("expected NonlocalExit, but the model ran to completion")
Exemplo n.º 5
0
 def test_trace_compose(self):
     """Check how the nesting order of ``trace`` and ``escape`` affects the trace.

     With ``escape`` wrapping the traced model, site "x" is expected to be
     recorded in the trace. With ``trace`` wrapping the escaped model, it is not.
     """
     # Scenario 1: escape(trace(model)).
     tm = poutine.trace(self.model)
     try:
         poutine.escape(tm, functools.partial(all_escape, poutine.Trace()))()
     except NonlocalExit:
         assert "x" in tm.trace
     else:
         raise AssertionError("expected NonlocalExit, but the model ran to completion")

     # Scenario 2: trace(escape(model)). It runs at top level rather than
     # inside the previous except-handler, so its failures are not reported
     # as chained "during handling of another exception" tracebacks.
     tem = poutine.trace(
         poutine.escape(self.model, functools.partial(all_escape,
                                                      poutine.Trace())))
     try:
         tem()
     except NonlocalExit:
         assert "x" not in tem.trace
     else:
         raise AssertionError("expected NonlocalExit, but the model ran to completion")
Exemplo n.º 6
0
 def test_trace_compose(self):
     """Check how the nesting order of ``trace`` and ``escape`` affects the trace.

     With ``escape`` wrapping the traced model, site "x" is expected to be
     recorded in the trace. With ``trace`` wrapping the escaped model, it is not.
     """
     # Scenario 1: escape(trace(model)).
     tm = poutine.trace(self.model)
     try:
         poutine.escape(tm, functools.partial(all_escape, poutine.Trace()))()
     except NonlocalExit:
         assert "x" in tm.trace
     else:
         raise AssertionError("expected NonlocalExit, but the model ran to completion")

     # Scenario 2: trace(escape(model)). It runs at top level rather than
     # inside the previous except-handler, so its failures are not reported
     # as chained "during handling of another exception" tracebacks.
     tem = poutine.trace(
         poutine.escape(self.model, functools.partial(all_escape,
                                                      poutine.Trace())))
     try:
         tem()
     except NonlocalExit:
         assert "x" not in tem.trace
     else:
         raise AssertionError("expected NonlocalExit, but the model ran to completion")
Exemplo n.º 7
0
    def _fn(*args, **kwargs):
        """Pop traces from ``queue`` until one replays ``fn`` to completion.

        Each popped ``(priority, trace)`` entry is replayed through ``fn``,
        with ``sample_escape`` (bound to that trace) choosing where execution
        escapes. If the run completes, its result is returned. If it escapes,
        the partial trace is extended at the escaping site via
        ``poutine.util.enum_extend``. Each extension is pushed back onto
        ``queue`` with its ``log_prob_sum()`` minus a small random jitter as
        the priority.

        :raises AssertionError: if ``queue`` is empty before a run completes
            (calling ``get()`` then would deadlock, per the message below).
        :raises ValueError: if no run completes within ``max_tries`` pops.
        """
        max_tries = int(1e6)
        for _ in range(max_tries):
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, and get() on an empty queue would then block
            # instead of failing.
            if queue.empty():
                raise AssertionError(
                    "trying to get() from an empty queue will deadlock")

            _priority, next_trace = queue.get()
            ftr = poutine.trace(
                poutine.escape(
                    poutine.replay(fn, next_trace),
                    functools.partial(sample_escape, next_trace),
                )
            )
            try:
                return ftr(*args, **kwargs)
            except NonlocalExit as site_container:
                site_container.reset_stack()
                for tr in poutine.util.enum_extend(
                    ftr.trace.copy(), site_container.site
                ):
                    # add a little bit of noise to the priority to break ties...
                    queue.put(
                        (tr.log_prob_sum().item() - torch.rand(1).item() * 1e-2, tr)
                    )

        # Format the integer limit (previously str(1e6) printed "1000000.0").
        raise ValueError("max tries ({}) exceeded".format(max_tries))
Exemplo n.º 8
0
def iter_discrete_traces(graph_type, fn, *args, **kwargs):
    """
    Iterate over all discrete choices of a stochastic function.

    When sampling continuous random variables, this behaves like `fn`.
    When sampling discrete random variables, this iterates over all choices.

    This yields `(scale, trace)` pairs, where `scale` is the probability of the
    discrete choices made in the `trace`.

    :param str graph_type: The type of the graph, e.g. "flat" or "dense".
    :param callable fn: A stochastic function.
    :param args: Positional arguments passed to `fn` on every execution.
    :param kwargs: Keyword arguments passed to `fn` on every execution.
    :returns: An iterator over (scale, trace) pairs.
    """
    # Depth-first work list of partial traces. A plain list used as a stack
    # pops in the same LIFO order as ``LifoQueue``, without that class's
    # thread-synchronization overhead. The stack never leaves this generator.
    stack = [Trace()]
    while stack:
        partial_trace = stack.pop()
        escape_fn = functools.partial(util.discrete_escape, partial_trace)
        traced_fn = poutine.trace(poutine.escape(poutine.replay(fn, partial_trace), escape_fn),
                                  graph_type=graph_type)
        try:
            full_trace = traced_fn.get_trace(*args, **kwargs)
        except util.NonlocalExit as e:
            # Execution escaped at a site selected by ``util.discrete_escape``.
            # Push the extended traces from ``util.enum_extend`` (same
            # insertion order as before) and explore them later.
            stack.extend(util.enum_extend(traced_fn.trace.copy(), e.site))
            continue

        # Scale trace by probability of discrete choices.
        log_pdf = full_trace.batch_log_pdf(site_filter=site_is_discrete)
        # batch_log_pdf may return a plain float here (presumably when no
        # discrete site contributes; TODO confirm). Normalize to a Variable
        # before exponentiating.
        if isinstance(log_pdf, float):
            log_pdf = torch.Tensor([log_pdf])
        if isinstance(log_pdf, torch.Tensor):
            log_pdf = Variable(log_pdf)
        scale = torch.exp(log_pdf.detach())
        yield scale, full_trace
Exemplo n.º 9
0
 def nested_model():
     """Sample four Bernoulli(0.5) sites, the last three under an escape handler.

     The handler's escape_fn matches only the site named "internal2".
     """
     def escape_at_internal2(msg):
         # Predicate handed to poutine.escape; true only for "internal2".
         return msg["name"] == "internal2"

     pyro.sample("internal0", dist.Bernoulli(0.5))
     with poutine.escape(escape_fn=escape_at_internal2):
         for site_name in ("internal1", "internal2", "internal3"):
             pyro.sample(site_name, dist.Bernoulli(0.5))