Example #1
def test_prob(nderivs):
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([0.5, 1., 1.5])
    p = pyro.param("p", torch.tensor(0.25))

    @config_enumerate
    def model(num_particles):
        p = pyro.param("p")
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(p))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(num_particles):
        pass

    elbo = TraceEnum_ELBO(max_plate_nesting=2)
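    # With an empty guide and z enumerated, the ELBO is exact, so the
    # negated loss below is the marginal log likelihood log p(data).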
    expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1)

    posterior_model = infer_discrete(config_enumerate(model, "parallel"),
                                     first_available_dim=-3)
    posterior_trace = poutine.trace(posterior_model).get_trace(
        num_particles=num_particles)
    actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2)

    if nderivs == 0:
        assert_equal(expected_logprob, actual_logprob, prec=1e-3)
    elif nderivs == 1:
        expected_grad = grad(expected_logprob, [p])[0]
        actual_grad = grad(actual_logprob, [p])[0]
        assert_equal(expected_grad, actual_grad, prec=1e-3)
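
For this two-component model the exact ELBO can also be checked by hand: with an empty guide, the negated loss above is log p(data), a mixture over the two values of the Bernoulli latent z. A minimal standalone sketch of that computation (not part of the test):

import torch
import pyro.distributions as dist

p = torch.tensor(0.25)
data = torch.tensor([0.5, 1., 1.5])
# log p(data) = log[(1 - p) * prod_i N(data_i; 0, 1) + p * prod_i N(data_i; 1, 1)]
logp_z0 = torch.log(1 - p) + dist.Normal(0., 1.).log_prob(data).sum()
logp_z1 = torch.log(p) + dist.Normal(1., 1.).log_prob(data).sum()
log_prob_data = torch.logsumexp(torch.stack([logp_z0, logp_z1]), dim=0)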
Example #2
def test_hmm_smoke(infer, temperature, length):

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
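        # Each transition row mixes a 0.7 self-transition with probability
        # 0.3 spread uniformly over all hidden states (rows sum to 1).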
        transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
        means = torch.arange(float(hidden_dim))
        states = [0]
        for t in pyro.markov(range(len(data))):
            states.append(
                pyro.sample("states_{}".format(t),
                            dist.Categorical(transition[states[-1]])))
            data[t] = pyro.sample("obs_{}".format(t),
                                  dist.Normal(means[states[-1]], 1.),
                                  obs=data[t])
        return states, data

    true_states, data = hmm([None] * length)
    assert len(data) == length
    assert len(true_states) == 1 + len(data)

    decoder = infer(config_enumerate(hmm),
                    first_available_dim=-1,
                    temperature=temperature)
    inferred_states, _ = decoder(data)
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format(list(map(int, true_states))))
    logger.info("inferred states: {}".format(list(map(int, inferred_states))))
Example #3
    def _setup_prototype(self, *args, **kwargs):
        # run the model so we can inspect its structure
        model = config_enumerate(self.model)
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._plates = {}
        for name, site in self.prototype_trace.iter_stochastic_nodes():
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                          'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
                params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
            else:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._plates[frame.name] = frame
                else:
                    raise NotImplementedError("AutoDiscreteParallel does not support sequential pyro.plate")
Example #4
    def _setup_prototype(self, *args, **kwargs):
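        # Note: this is an earlier variant of the method in Example #3, written
        # against the legacy pyro.iarange/pyro.irange API that was later
        # unified into pyro.plate.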
        # run the model so we can inspect its structure
        model = config_enumerate(self.model, default="parallel")
        self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._iaranges = {}
        for name, site in self.prototype_trace.nodes.items():
            if site["type"] != "sample" or site["is_observed"]:
                continue
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                          'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist in (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical):
                params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
            else:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self._iaranges[frame.name] = frame
                else:
                    raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
Example #5
def test_warning():
    data = torch.randn(4)

    def model():
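        # x is a discrete latent; unless it is marked for enumeration (or the
        # strict-enumeration warning is disabled), infer_discrete has nothing
        # to enumerate and warns.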
        x = pyro.sample("x", dist.Categorical(torch.ones(3)))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(x.float(), 1), obs=data)

    model_1 = infer_discrete(model, first_available_dim=-2)
    model_2 = infer_discrete(model,
                             first_available_dim=-2,
                             strict_enumeration_warning=False)
    model_3 = infer_discrete(config_enumerate(model), first_available_dim=-2)

    # model_1 should raise warnings.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        model_1()
    assert w, 'No warnings were raised'

    # model_2 and model_3 should both be valid.
    model_2()
    model_3()
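
A third variant (a sketch only, reusing data and the imports assumed above; the name model_4 is hypothetical) is to annotate the discrete site for enumeration directly at the sample statement instead of wrapping the whole model in config_enumerate; this also avoids the warning.

def model_4():
    x = pyro.sample("x", dist.Categorical(torch.ones(3)),
                    infer={"enumerate": "parallel"})
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(x.float(), 1), obs=data)

model_4 = infer_discrete(model_4, first_available_dim=-2)
model_4()  # valid: at least one site is configured for enumeration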