def test_prob(nderivs):
    """Compare infer_discrete's posterior log-probability with TraceEnum_ELBO.

    The model has one Bernoulli latent ``z`` and Normal observations ``x``::

            +-------+
        z --|--> x  |
            +-------+

    ``expected_logprob`` is the negated ``TraceEnum_ELBO`` loss of the
    enumerated model (with an empty guide). ``actual_logprob`` is computed
    from a trace of the ``infer_discrete`` posterior model, replicated over
    ``num_particles`` particles along dim -2 and reduced by ``log_mean_prob``.

    :param int nderivs: 0 compares the log-probabilities themselves; 1
        compares their gradients with respect to the param ``p``.
    :raises ValueError: for any other ``nderivs``. Without this, an
        unsupported parametrization would pass without asserting anything.
    """
    num_particles = 10000
    data = torch.tensor([0.5, 1., 1.5])
    p = pyro.param("p", torch.tensor(0.25))

    @config_enumerate
    def model(num_particles):
        prob = pyro.param("p")
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(prob))
            # Size the data plate from the observations instead of a literal.
            with pyro.plate("data", len(data)):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(num_particles):
        pass

    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    # z is enumerated (model is decorated with config_enumerate), so one
    # particle is enough for the ELBO-based reference value.
    expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1)

    # Plates occupy dims -1 ("data") and -2 ("num_particles"), so enumeration
    # has to start at dim -3.
    posterior_model = infer_discrete(config_enumerate(model, "parallel"),
                                     first_available_dim=-3)
    posterior_trace = poutine.trace(posterior_model).get_trace(
        num_particles=num_particles)
    actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2)

    if nderivs == 0:
        assert_equal(expected_logprob, actual_logprob, prec=1e-3)
    elif nderivs == 1:
        expected_grad = grad(expected_logprob, [p])[0]
        actual_grad = grad(actual_logprob, [p])[0]
        assert_equal(expected_grad, actual_grad, prec=1e-3)
    else:
        raise ValueError("unsupported nderivs: {}".format(nderivs))
def test_hmm_smoke(infer, temperature, length):
    """Smoke-test a discrete-inference decoder on a small HMM.

    A trajectory of ``length`` steps is sampled from ``hmm``. The decoder
    built by ``infer(config_enumerate(hmm), ...)`` is then run on the sampled
    observations. Only the number of returned states is checked: one per
    step plus the initial state. The true and inferred sequences are logged
    rather than compared.
    """

    # Keep this model in sync with the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        # Each row of the transition matrix gets 0.7 extra mass on its diagonal.
        transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
        means = torch.arange(float(hidden_dim))
        states = [0]
        for t in pyro.markov(range(len(data))):
            prev = states[-1]
            current = pyro.sample("states_{}".format(t),
                                  dist.Categorical(transition[prev]))
            states.append(current)
            # data[t] is overwritten with the site's value. When it starts
            # as None (see the call below), that value is a fresh sample.
            emission = dist.Normal(means[current], 1.)
            data[t] = pyro.sample("obs_{}".format(t), emission, obs=data[t])
        return states, data

    true_states, data = hmm([None] * length)
    assert len(data) == length
    assert len(true_states) == len(data) + 1

    decoder = infer(config_enumerate(hmm), first_available_dim=-1,
                    temperature=temperature)
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    true_ints = [int(s) for s in true_states]
    logger.info("true states: {}".format(true_ints))
    inferred_ints = [int(s) for s in inferred_states]
    logger.info("inferred states: {}".format(inferred_ints))
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and record its discrete sites and plates.

    ``self.model`` is wrapped in ``config_enumerate`` and traced with
    ``*args``/``**kwargs`` under ``poutine.block``. Subsample sites are then
    pruned from the resulting ``self.prototype_trace``. For each stochastic
    site this fills in:

    * ``self._discrete_sites``: ``(site, Dist, params)`` tuples, where
      ``params`` holds a detached clone of the site's ``probs`` and its
      constraint;
    * ``self._cond_indep_stacks``: site name -> the site's ``cond_indep_stack``;
    * ``self._plates``: plate name -> vectorized plate frame.

    :raises NotImplementedError: if a site is not configured for parallel
        enumeration, uses a distribution other than Bernoulli, Categorical
        or OneHotCategorical, or sits inside a non-vectorized plate.
    """
    enum_model = config_enumerate(self.model)
    self.prototype_trace = poutine.block(poutine.trace(enum_model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(self.prototype_trace)
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    supported = (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical)
    self._discrete_sites = []
    self._cond_indep_stacks = {}
    self._plates = {}
    for site_name, site in self.prototype_trace.iter_stochastic_nodes():
        if site["infer"].get("enumerate") != "parallel":
            raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                      'configured for parallel enumeration'.format(site_name))

        # Only distributions parameterized by `probs` are supported.
        site_fn = site["fn"]
        dist_cls = type(site_fn)
        if dist_cls not in supported:
            raise NotImplementedError("{} is not supported".format(dist_cls.__name__))
        probs_param = ("probs", site_fn.probs.detach().clone(), site_fn.arg_constraints["probs"])
        self._discrete_sites.append((site, dist_cls, [probs_param]))

        # Record the site's plates; sequential plates cannot be handled.
        stack = site["cond_indep_stack"]
        self._cond_indep_stacks[site_name] = stack
        for frame in stack:
            if not frame.vectorized:
                raise NotImplementedError("AutoDiscreteParallel does not support sequential pyro.plate")
            self._plates[frame.name] = frame
def _setup_prototype(self, *args, **kwargs):
    """Trace the model once and record its latent discrete sites and iaranges.

    ``self.model`` is wrapped in ``config_enumerate(..., default="parallel")``
    and traced with ``*args``/``**kwargs`` under ``poutine.block``. Subsample
    sites are then pruned from the resulting ``self.prototype_trace``. Only
    unobserved ``"sample"`` nodes are examined. For each one this fills in:

    * ``self._discrete_sites``: ``(site, Dist, params)`` tuples, where
      ``params`` holds a detached clone of the site's ``probs`` and its
      constraint;
    * ``self._cond_indep_stacks``: site name -> the site's ``cond_indep_stack``;
    * ``self._iaranges``: iarange name -> vectorized frame.

    :raises NotImplementedError: if a latent site is not configured for
        parallel enumeration, uses a distribution other than Bernoulli,
        Categorical or OneHotCategorical, or sits inside a non-vectorized
        frame (``pyro.irange``).
    """
    enum_model = config_enumerate(self.model, default="parallel")
    self.prototype_trace = poutine.block(poutine.trace(enum_model).get_trace)(*args, **kwargs)
    self.prototype_trace = prune_subsample_sites(self.prototype_trace)
    if self.master is not None:
        self.master()._check_prototype(self.prototype_trace)

    supported = (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical)
    self._discrete_sites = []
    self._cond_indep_stacks = {}
    self._iaranges = {}
    for site_name, site in self.prototype_trace.nodes.items():
        # Skip non-sample nodes and observed sites. "is_observed" is only read
        # for sample nodes, thanks to short-circuiting.
        is_latent_sample = site["type"] == "sample" and not site["is_observed"]
        if not is_latent_sample:
            continue
        if site["infer"].get("enumerate") != "parallel":
            raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                      'configured for parallel enumeration'.format(site_name))

        # Only distributions parameterized by `probs` are supported.
        site_fn = site["fn"]
        dist_cls = type(site_fn)
        if dist_cls not in supported:
            raise NotImplementedError("{} is not supported".format(dist_cls.__name__))
        probs_param = ("probs", site_fn.probs.detach().clone(), site_fn.arg_constraints["probs"])
        self._discrete_sites.append((site, dist_cls, [probs_param]))

        # Record the site's frames; sequential (irange) frames cannot be handled.
        stack = site["cond_indep_stack"]
        self._cond_indep_stacks[site_name] = stack
        for frame in stack:
            if not frame.vectorized:
                raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
            self._iaranges[frame.name] = frame
def test_warning():
    """Check infer_discrete's enumeration warning and its two opt-outs.

    ``model`` draws a discrete latent ``x`` but is not configured for
    enumeration. Wrapping it directly with ``infer_discrete`` is expected to
    emit at least one warning. Passing ``strict_enumeration_warning=False``,
    or applying ``config_enumerate`` first, must let the model run without
    raising.
    """
    data = torch.randn(4)

    def model():
        x = pyro.sample("x", dist.Categorical(torch.ones(3)))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(x.float(), 1), obs=data)

    warn_model = infer_discrete(model, first_available_dim=-2)
    quiet_model = infer_discrete(model, first_available_dim=-2,
                                 strict_enumeration_warning=False)
    enumerated_model = infer_discrete(config_enumerate(model),
                                      first_available_dim=-2)

    # The un-enumerated, strict model should warn.
    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        warn_model()
    assert recorded, 'No warnings were raised'

    # Both opt-outs should simply run.
    for valid_model in (quiet_model, enumerated_model):
        valid_model()