def test_trace_compose(self):
    # Case 1: the escape handler wraps an already-traced model.
    # The test asserts that "x" is present in the trace afterwards.
    traced = poutine.trace(self.model)
    try:
        poutine.escape(traced, functools.partial(all_escape, poutine.Trace()))()
    except NonlocalExit:
        assert "x" in traced.trace
    else:
        # Reaching here means no NonlocalExit was raised.
        assert False

    # Case 2: the trace handler wraps the escape handler.
    # The test asserts that "x" is absent from the trace afterwards.
    try:
        escaped_then_traced = poutine.trace(
            poutine.escape(self.model, functools.partial(all_escape, poutine.Trace())))
        escaped_then_traced()
    except NonlocalExit:
        assert "x" not in escaped_then_traced.trace
    else:
        assert False
def _get_traces(self, model, guide, args, kwargs):
    """
    Yield one ``self._get_trace(model, guide, args, kwargs)`` result per
    trace popped from a LIFO queue that drives the guide.

    Both programs are wrapped in ``EnumMessenger`` for parallel enumeration,
    and the guide is additionally wrapped in ``poutine.queue`` with the
    ``iter_discrete_escape``/``iter_discrete_extend`` hooks.
    """
    if self.max_plate_nesting == float('inf'):
        self._guess_max_plate_nesting(model, guide, args, kwargs)
    if self.vectorize_particles:
        guide = self._vectorized_num_particles(guide)
        model = self._vectorized_num_particles(model)

    # Parallel enumeration over the guide and model. The guide starts
    # allocating enumeration dims just left of the plate dims; the model's
    # messenger gets no first_available_dim, so (per the original design)
    # it continues from the _ENUM_ALLOCATOR state left by the guide and
    # allocates its dims further to the left.
    guide_enumerator = EnumMessenger(first_available_dim=-1 - self.max_plate_nesting)
    model_enumerator = EnumMessenger()  # preserve _ENUM_ALLOCATOR state
    guide, model = guide_enumerator(guide), model_enumerator(model)

    pending_traces = queue.LifoQueue()
    guide = poutine.queue(guide, pending_traces,
                          escape_fn=iter_discrete_escape,
                          extend_fn=iter_discrete_extend)

    # With vectorized particles a single outer pass covers every particle.
    outer_passes = 1 if self.vectorize_particles else self.num_particles
    for _ in range(outer_passes):
        pending_traces.put(poutine.Trace())
        while not pending_traces.empty():
            yield self._get_trace(model, guide, args, kwargs)
def setUp(self):
    # Simple Gaussian-mixture HMM: three Bernoulli latent steps, each
    # followed by a Normal site observed at torch.ones(1).
    def model():
        probs = pyro.param("probs", torch.tensor([[0.8], [0.3]]))
        loc = pyro.param("loc", torch.tensor([[-0.1], [0.9]]))
        scale = torch.ones(1, 1)
        latents = [torch.ones(1)]
        observes = []
        for step in range(3):
            # The transition is indexed by the previous latent state.
            prev_state = latents[-1][0].long().data
            latents.append(
                pyro.sample("latent_{}".format(str(step)), Bernoulli(probs[prev_state])))
            # The emission is indexed by the latent state just sampled.
            cur_state = latents[-1][0].long().data
            observes.append(
                pyro.sample("observe_{}".format(str(step)),
                            Normal(loc[cur_state], scale),
                            obs=torch.ones(1)))
        return latents

    observe_names = ["observe_{}".format(str(step)) for step in range(3)]
    latent_names = ["latent_{}".format(str(step)) for step in range(3)]
    self.sites = observe_names + latent_names + ["_INPUT", "_RETURN"]
    self.model = model
    self.queue = Queue()
    self.queue.put(poutine.Trace())
def _traces(self, *args, **kwargs):
    """
    Repeatedly trace ``self.model`` through a FIFO ``poutine.queue`` until
    the queue of partial traces is empty, yielding
    ``(trace, trace.log_prob_sum())`` for each run.
    """
    pending_traces = queue.Queue()
    pending_traces.put(poutine.Trace())
    queued_model = poutine.queue(self.model, queue=pending_traces, max_tries=self.max_tries)
    tracer = poutine.trace(queued_model)
    while not pending_traces.empty():
        trace = tracer.get_trace(*args, **kwargs)
        yield trace, trace.log_prob_sum()
def test_all_escape(self):
    # all_escape should interrupt the model at its first sample site, "x".
    escaping_model = poutine.escape(self.model, functools.partial(all_escape, poutine.Trace()))
    try:
        escaping_model()
    except NonlocalExit as exit_signal:
        assert exit_signal.site["name"] == "x"
    else:
        assert False
def setUp(self):
    # Simple Gaussian-mixture HMM written against the legacy
    # Variable / pyro.observe API.
    def model():
        ps = pyro.param("ps", Variable(torch.Tensor([[0.8], [0.3]])))
        mu = pyro.param("mu", Variable(torch.Tensor([[-0.1], [0.9]])))
        sigma = Variable(torch.ones(1, 1))
        latents = [Variable(torch.ones(1))]
        observes = []
        for step in range(3):
            latent = pyro.sample("latent_{}".format(str(step)),
                                 Bernoulli(ps[latents[-1][0].long().data]))
            latents.append(latent)
            # Emission indexed by the latent just appended.
            observation = pyro.observe("observe_{}".format(str(step)),
                                       Normal(mu[latents[-1][0].long().data], sigma),
                                       pyro.ones(1))
            observes.append(observation)
        return latents

    observe_names = ["observe_{}".format(str(step)) for step in range(3)]
    latent_names = ["latent_{}".format(str(step)) for step in range(3)]
    self.sites = observe_names + latent_names + ["_INPUT", "_RETURN"]
    self.model = model
    self.queue = Queue()
    self.queue.put(poutine.Trace())
def test_discrete_escape(self):
    # discrete_escape should interrupt the model at the discrete site "y".
    escape_hook = functools.partial(discrete_escape, poutine.Trace())
    escaping_model = poutine.escape(self.model, escape_fn=escape_hook)
    try:
        escaping_model()
    except NonlocalExit as exit_signal:
        assert exit_signal.site["name"] == "y"
    else:
        assert False
def _traces(self, *args, **kwargs):
    """
    Best-first traversal: pop partial traces from a priority queue and yield
    ``(trace, trace.log_prob_sum())`` for at most ``self.num_samples`` runs,
    stopping early once the queue is exhausted.
    """
    frontier = queue.PriorityQueue()
    # add a little bit of noise to the priority to break ties...
    initial_priority = torch.zeros(1).item() - torch.rand(1).item() * 1e-2
    frontier.put((initial_priority, poutine.Trace()))
    queued_model = pqueue(self.model, queue=frontier)
    for _ in range(self.num_samples):
        if frontier.empty():
            # num_samples was too large!
            break
        trace = poutine.trace(queued_model).get_trace(*args, **kwargs)  # XXX should block
        yield trace, trace.log_prob_sum()
def setUp(self):
    # Simple model with 1 continuous + 1 discrete + 1 continuous variable.
    def model():
        p = torch.tensor([0.5])
        loc = torch.zeros(1)
        scale = torch.ones(1)
        x = pyro.sample("x", Normal(loc, scale))  # Before the discrete variable.
        y = pyro.sample("y", Bernoulli(p))
        z = pyro.sample("z", Normal(loc, scale))  # After the discrete variable.
        return {"x": x, "y": y, "z": z}

    self.sites = ["x", "y", "z", "_INPUT", "_RETURN"]
    self.model = model
    self.queue = Queue()
    self.queue.put(poutine.Trace())
def _traces(self, *args, **kwargs):
    """
    Entry point of the search algorithm.

    Runs the queued model until ``self.queue`` is empty; collecting the
    marginal histogram of everything yielded performs exact inference.
    Note that ``self.queue`` is replaced with a fresh queue on every call.

    :returns: generator of ``(trace, trace.log_pdf())`` pairs
    :rtype: generator
    """
    # currently only using the standard library queue
    self.queue = Queue()
    self.queue.put(poutine.Trace())
    tracer = poutine.trace(
        poutine.queue(self.model, queue=self.queue, max_tries=self.max_tries))
    while not self.queue.empty():
        trace = tracer.get_trace(*args, **kwargs)
        yield (trace, trace.log_pdf())
def test_decorator_interface_queue():
    # poutine.queue used as a decorator should still expose every site.
    expected_sites = ["x", "y", "z", "_INPUT", "_RETURN"]
    trace_queue = Queue()
    trace_queue.put(poutine.Trace())

    @poutine.queue(queue=trace_queue)
    def model():
        p = torch.tensor([0.5])
        loc = torch.zeros(1)
        scale = torch.ones(1)
        x = pyro.sample("x", Normal(loc, scale))
        y = pyro.sample("y", Bernoulli(p))
        z = pyro.sample("z", Normal(loc, scale))
        return {"x": x, "y": y, "z": z}

    trace = poutine.trace(model).get_trace()
    for site_name in expected_sites:
        assert site_name in trace
def register_model(**poutine_kwargs):
    """
    Decorator to register a model as an example model for testing.

    The decorated function is wrapped in ``ExampleModel(fn, poutine_kwargs)``,
    appended to ``EXAMPLE_MODELS``, and its ``__name__`` is appended to
    ``EXAMPLE_MODEL_IDS``. The decorator returns the ``ExampleModel``
    instance, not the raw function, so the decorated name is bound to the
    wrapper.

    :param poutine_kwargs: keyword arguments stored on the ``ExampleModel``;
        presumably keyed by poutine handler name (replay/block/condition/do,
        as used below) -- confirm against ``ExampleModel``.
    """
    def register_fn(fn):
        model = ExampleModel(fn, poutine_kwargs)
        EXAMPLE_MODELS.append(model)
        EXAMPLE_MODEL_IDS.append(model.fn.__name__)
        return model
    return register_fn


# A model with no sample sites, registered with empty handler arguments.
@register_model(replay={'trace': poutine.Trace()},
                block={},
                condition={'data': {}},
                do={'data': {}})
def trivial_model():
    return []


# Hand-built trace containing a single unobserved sample site "normal_0",
# used as the replay trace for the model registered below.
tr_normal = poutine.Trace()
tr_normal.add_node("normal_0", type="sample", is_observed=False, value=torch.zeros(1), infer={})


# NOTE(review): the function this decorator applies to is not visible in this
# chunk; the decorator must be immediately followed by its target definition.
@register_model(replay={'trace': tr_normal},
                block={'hide': ['normal_0']},
                condition={'data': {'normal_0': torch.zeros(1)}},
                do={'data': {'normal_0': torch.zeros(1)}})
def _sample_posterior_from_trace(model, enum_trace, temperature, *args, **kwargs):
    """
    Run a forward-backward pass over the packed log-probs of ``enum_trace``,
    build a "collapsed" trace of the unobserved sample sites, and replay
    ``model(*args, **kwargs)`` against it under ``SamplePosteriorMessenger``.

    ``enum_trace`` must already carry packed tensors: each sample node needs
    ``node["packed"]["masked_log_prob"]`` or, failing that, both
    ``"unscaled_log_prob"`` and ``"mask"`` from which it is computed.

    Side effects: caches ``"masked_log_prob"`` into ``node["packed"]`` of
    ``enum_trace`` and sets ``_pyro_backward_result`` on the query tensors.

    :param temperature: forwarded to ``_make_ring``; presumably selects
        sampling vs. MAP behaviour -- confirm in ``_make_ring``.
    :returns: the return value of ``model(*args, **kwargs)`` under replay.
    """
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    sum_dims = set()
    queries = []
    dim_to_size = {}
    cost_terms = OrderedDict()
    enum_terms = OrderedDict()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # Only vectorized plates with more than one element contribute a
            # plate symbol to the ordinal.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            # For sites that depend on an enumerated variable, we need to apply
            # the mask but not the scale when sampling.
            if "masked_log_prob" not in node["packed"]:
                node["packed"]["masked_log_prob"] = packed.scale_and_mask(
                    node["packed"]["unscaled_log_prob"], mask=node["packed"]["mask"])
            log_prob = node["packed"]["masked_log_prob"]
            # Every non-plate dim of this term is summed out.
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Terms sharing no dim with sum_dims are skipped entirely: they
            # are neither contracted nor registered as queries.
            if sum_dims.isdisjoint(log_prob._pyro_dims):
                continue
            dim_to_size.update(zip(log_prob._pyro_dims, log_prob.shape))
            if node["infer"].get("_enumerate_dim") is None:
                cost_terms.setdefault(ordinal, []).append(log_prob)
            else:
                enum_terms.setdefault(ordinal, []).append(log_prob)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # We take special care to match the term ordering in
    # pyro.infer.traceenum_elbo._compute_model_factors() to allow
    # contract_tensor_tree() to use shared_intermediates() inside
    # TraceEnumSample_ELBO. The special ordering is: first all cost terms in
    # order of model_trace, then all enum_terms in order of model trace.
    log_probs = cost_terms
    for ordinal, terms in enum_terms.items():
        log_probs.setdefault(ordinal, []).extend(terms)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    cache = getattr(enum_trace, "_sharing_cache", {})
    ring = _make_ring(temperature, cache, dim_to_size)
    with shared_intermediates(cache):
        log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        # Any query resolved by this ordinal's backward pass is attributed to
        # it (first ordinal wins).
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["masked_log_prob"]
            # Sites skipped above (no summed dims) presumably lack this
            # attribute and keep their original value/cond_indep_stack.
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: drop vectorized size>1 plates
                # that are not part of this site's connected component.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1) or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                    # Each slice of `sample` along its leading dim is the
                    # index tensor for one sampled dim in _pyro_sample_dims.
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """
    Trace ``model`` under parallel enumeration, run forward-backward over the
    packed log-probs, and replay ``model(*args, **kwargs)`` against a
    collapsed trace of the unobserved sample sites.

    For internal use by ``infer_discrete``.

    :param first_available_dim: passed to ``EnumerateMessenger`` as the first
        dim available for enumeration.
    :param temperature: forwarded to ``_make_ring``; presumably selects
        sampling vs. MAP behaviour -- confirm in ``_make_ring``.
    :returns: the return value of ``model(*args, **kwargs)`` under replay.
    """
    # For internal use by infer_discrete.
    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()
    sum_dims = set()
    queries = []
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # Packed tensors carry no dim for size-1 plates, so only plates
            # with more than one element belong in the ordinal. The previous
            # code included every vectorized plate and then called
            # ``sum_dims.remove(symbol)``, which raised KeyError for size-1
            # plates whose symbol never entered sum_dims.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized and f.size > 1)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            # Every non-plate dim of this term is summed out.
            sum_dims.update(frozenset(log_prob._pyro_dims) - ordinal)
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims, ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Note: this is quadratic in number of ordinals
        # Any query resolved by this ordinal's backward pass is attributed to
        # it (first ordinal wins).
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack: drop vectorized size>1 plates
                # outside this site's connected component. Size-1 plates never
                # enter an ordinal, so they are always kept.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not (f.vectorized and f.size > 1) or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
                    # Each slice of `sample` along its leading dim is the
                    # index tensor for one sampled dim in _pyro_sample_dims.
                    for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value, enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)