Ejemplo n.º 1
0
def test_gmm_iter_discrete_traces(data_size, graph_type, model):
    """Enumerating the model over ``data_size`` points must yield ``2**data_size`` traces."""
    pyro.clear_param_store()
    observations = torch.arange(0, data_size)
    enumerated_model = config_enumerate(model)
    trace_iter = iter_discrete_traces(
        graph_type, enumerated_model, data=observations, verbose=True)
    num_traces = len(list(trace_iter))
    # The non-vectorized version grows exponentially with data_size.
    assert num_traces == 2**data_size
Ejemplo n.º 2
0
def test_gmm_batch_iter_discrete_traces(model, data_size, graph_type):
    """Enumerating the batched model must yield exactly two traces for any ``data_size``."""
    pyro.clear_param_store()
    observations = torch.arange(0, data_size)
    enumerated_model = config_enumerate(model)
    trace_iter = iter_discrete_traces(graph_type, enumerated_model, data=observations)
    num_traces = len(list(trace_iter))
    # The vectorized version does not depend on data_size.
    assert num_traces == 2
Ejemplo n.º 3
0
def test_iter_discrete_traces_vector(graph_type):
    """
    Enumerate a model with a batched Bernoulli site ``x`` and a batched
    Categorical site ``y`` and check both the number of traces and the
    scale attached to each trace.

    The expected scale is the joint probability of the first batch element's
    discrete choices, i.e. ``exp(log p(x) + log p(y))``.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # Two Bernoulli outcomes times one outcome per Categorical class.
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # Log-probabilities of independent sites add; multiplying them (as the
        # previous version did) does not give the joint probability.
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) +
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
Ejemplo n.º 4
0
def test_iter_discrete_traces_vector(graph_type):
    """
    Enumerate a model with a batched Bernoulli site ``x`` and a batched
    Categorical site ``y`` and check both the number of traces and the
    scale attached to each trace.

    The expected scale is the joint probability of the first batch element's
    discrete choices, i.e. ``exp(log p(x) + log p(y))``.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param(
            "ps",
            Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2,
                                                          0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # Two Bernoulli outcomes times one outcome per Categorical class.
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # Log-probabilities of independent sites add; multiplying them (as the
        # previous version did) does not give the joint probability.
        expected_scale = torch.exp(
            dist.Bernoulli(p).log_pdf(x) +
            dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
Ejemplo n.º 5
0
def test_iter_discrete_traces_order(depth, graph_type):
    """
    Every enumerated trace must list its sample sites ``x0 .. x{depth-1}``
    in the order the model visited them, and there must be ``2**depth`` traces.
    """

    @config_enumerate
    def model(depth):
        for index in range(depth):
            pyro.sample("x{}".format(index), dist.Bernoulli(0.5))

    traces = list(iter_discrete_traces(graph_type, model, depth))
    assert len(traces) == 2 ** depth

    expected_sites = ["x{}".format(index) for index in range(depth)]
    for trace in traces:
        # Collect the names of sample sites in the trace's own node order.
        observed_sites = []
        for name, site in trace.nodes.items():
            if site["type"] == "sample":
                observed_sites.append(name)
        assert observed_sites == expected_sites
Ejemplo n.º 6
0
def test_iter_discrete_traces_scalar(graph_type):
    """
    Enumerate a model with a scalar Bernoulli site and a 4-class Categorical
    site; the number of traces must be 2 times the number of classes.
    """
    pyro.clear_param_store()

    @config_enumerate
    def model():
        bernoulli_p = pyro.param("p", torch.tensor(0.05))
        class_probs = pyro.param("probs", torch.tensor([0.1, 0.2, 0.3, 0.4]))
        x_value = pyro.sample("x", dist.Bernoulli(bernoulli_p))
        y_value = pyro.sample("y", dist.Categorical(class_probs))
        return {"x": x_value, "y": y_value}

    traces = list(iter_discrete_traces(graph_type, model))

    probs = pyro.param("probs")
    expected_count = 2 * len(probs)
    assert len(traces) == expected_count
Ejemplo n.º 7
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator of ``(model_trace, guide_trace)`` pairs.

        For each particle, every trace produced by ``iter_discrete_traces``
        for the enumeration-wrapped guide is replayed through the model; both
        traces are pruned of subsample sites and scored before being yielded.
        Shape and enumeration checks run only while validation is enabled.
        """
        # Wrap the guide so that parallel enumeration is available.
        guide = poutine.enum(guide,
                             first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args,
                                                    **kwargs):
                replayed_model = poutine.replay(model, trace=guide_trace)
                traced_model = poutine.trace(replayed_model, graph_type="flat")
                model_trace = traced_model.get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace,
                                            self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)

                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            any_enumerated = True

                    if self.strict_enumeration_warning:
                        if not any_enumerated:
                            message = (
                                'TraceEnum_ELBO found no sample sites configured for enumeration. '
                                'If you want to enumerate sites, you need to @config_enumerate or set '
                                'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                'If you do not want to enumerate, consider using Trace_ELBO instead.'
                            )
                            warnings.warn(message)

                yield model_trace, guide_trace
Ejemplo n.º 8
0
def test_iter_discrete_traces_vector(graph_type):
    """
    Enumerate a model whose Bernoulli and Categorical sites live inside a
    size-2 ``iarange``; the number of traces must be 2 times the number of
    Categorical classes.
    """
    pyro.clear_param_store()

    @config_enumerate
    def model():
        bernoulli_p = pyro.param("p", torch.tensor([0.05, 0.15]))
        class_probs = pyro.param(
            "probs",
            torch.tensor([[0.1, 0.2, 0.3, 0.4],
                          [0.4, 0.3, 0.2, 0.1]]))
        with pyro.iarange("iarange", 2):
            x_value = pyro.sample("x", dist.Bernoulli(bernoulli_p))
            y_value = pyro.sample("y", dist.Categorical(class_probs))
        # Each site must carry one value per iarange element.
        assert x_value.size() == (2,)
        assert y_value.size() == (2,)
        return {"x": x_value, "y": y_value}

    traces = list(iter_discrete_traces(graph_type, model))

    probs = pyro.param("probs")
    expected_count = 2 * probs.size(-1)
    assert len(traces) == expected_count
Ejemplo n.º 9
0
def test_iter_discrete_traces_scalar(graph_type):
    """
    Enumerate a model with one Bernoulli site and one 4-class Categorical site
    and check that each trace's scale equals the product of the probabilities
    of the values it picked.
    """
    pyro.clear_param_store()

    def model():
        bernoulli_p = pyro.param("p", Variable(torch.Tensor([0.05])))
        class_ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
        x_value = pyro.sample("x", dist.Bernoulli(bernoulli_p))
        y_value = pyro.sample("y", dist.Categorical(class_ps, one_hot=False))
        return {"x": x_value, "y": y_value}

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * len(ps)

    # Probability of x == 0 and of x == 1, indexed by the sampled value.
    x_probs = [1 - p[0], p[0]]
    for scale, trace in traces:
        nodes = trace.nodes
        x = nodes["x"]["value"].data.long().view(-1)[0]
        y = nodes["y"]["value"].data.long().view(-1)[0]
        joint_prob = x_probs[x] * ps[y]
        assert_equal(scale, Variable(torch.Tensor([joint_prob])))
Ejemplo n.º 10
0
def test_iter_discrete_traces_scalar(graph_type):
    """
    Enumerate a model with one Bernoulli site and one 4-class Categorical site
    and check that each trace's scale equals the product of the probabilities
    of the values it picked.
    """
    pyro.clear_param_store()

    def model():
        bernoulli_p = pyro.param("p", Variable(torch.Tensor([0.05])))
        class_ps = pyro.param("ps", Variable(torch.Tensor([0.1, 0.2, 0.3, 0.4])))
        x_value = pyro.sample("x", dist.Bernoulli(bernoulli_p))
        y_value = pyro.sample("y", dist.Categorical(class_ps, one_hot=False))
        return {"x": x_value, "y": y_value}

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    assert len(traces) == 2 * len(ps)

    # Probability of x == 0 and of x == 1, indexed by the sampled value.
    x_probs = [1 - p[0], p[0]]
    for scale, trace in traces:
        nodes = trace.nodes
        x = nodes["x"]["value"].data.long().view(-1)[0]
        y = nodes["y"]["value"].data.long().view(-1)[0]
        joint_prob = x_probs[x] * ps[y]
        assert_equal(scale, Variable(torch.Tensor([joint_prob])))
Ejemplo n.º 11
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator of ``(weight, model_trace, guide_trace, log_r)`` tuples.

        For each particle the guide is run and the model is replayed against
        it. When ``self.enum_discrete`` is set, each particle expands into
        the bag of traces produced by ``iter_discrete_traces`` and each is
        weighted by its scale; otherwise a single trace pair per particle is
        yielded with uniform weight.

        XXX support for automatically setting args/kwargs to volatile?
        """

        for _ in range(self.num_particles):
            if self.enum_discrete:
                # Each particle contributes a whole bag of enumerated traces.
                enumerated = iter_discrete_traces("flat", guide, *args, **kwargs)
                for scale, guide_trace in enumerated:
                    replayed_model = poutine.replay(model, guide_trace)
                    traced_model = poutine.trace(replayed_model, graph_type="flat")
                    model_trace = traced_model.get_trace(*args, **kwargs)

                    check_model_guide_match(model_trace, guide_trace)
                    guide_trace = prune_subsample_sites(guide_trace)
                    model_trace = prune_subsample_sites(model_trace)
                    check_enum_discrete_can_run(model_trace, guide_trace)

                    model_log_pdf = model_trace.batch_log_pdf()
                    guide_log_pdf = guide_trace.batch_log_pdf()
                    log_r = model_log_pdf - guide_log_pdf
                    weight = scale / self.num_particles
                    yield weight, model_trace, guide_trace, log_r
            else:
                guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
                replayed_model = poutine.replay(model, guide_trace)
                model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)

                check_model_guide_match(model_trace, guide_trace)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)

                log_r = model_trace.log_pdf() - guide_trace.log_pdf()
                weight = 1.0 / self.num_particles
                yield weight, model_trace, guide_trace, log_r
Ejemplo n.º 12
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator of ``(model_trace, guide_trace)`` pairs.

        For each particle, every trace produced by ``iter_discrete_traces``
        for the enumeration-wrapped guide is replayed through the model; both
        traces are pruned of subsample sites and scored before being yielded.
        Shape and enumeration checks run only while validation is enabled.
        """
        # Wrap the guide so that parallel enumeration is available.
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                replayed_model = poutine.replay(model, trace=guide_trace)
                traced_model = poutine.trace(replayed_model, graph_type="flat")
                model_trace = traced_model.get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)

                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            any_enumerated = True

                    if self.strict_enumeration_warning:
                        if not any_enumerated:
                            message = ('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                       'If you want to enumerate sites, you need to @config_enumerate or set '
                                       'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                       'If you do not want to enumerate, consider using Trace_ELBO instead.')
                            warnings.warn(message)

                yield model_trace, guide_trace
Ejemplo n.º 13
0
        def lp_fn(input_dict):
            """
            Accumulate ``log_prob`` over all nodes of all enumerated traces.

            The values in ``input_dict`` are first written into the enclosing
            ``model_trace`` (keyed by node name), the enclosing ``model`` is
            replayed against that trace, and the per-node ``log_prob`` of
            every enumerated trace is added up, skipping the ``_INPUT`` and
            ``_RETURN`` nodes.
            """
            skipped_nodes = {"_INPUT", "_RETURN"}

            # Overwrite the stored values in the enclosing trace.
            for site_name, site_value in input_dict.items():
                model_trace.nodes[site_name]["value"] = site_value

            replayed_model = pyro.poutine.replay(model, model_trace)

            log_p = 0
            for trace_enum in iter_discrete_traces("flat", fn=replayed_model):
                trace_enum.compute_log_prob()

                for node_name, node in trace_enum.nodes.items():
                    if node_name not in skipped_nodes:
                        node_log_prob = node["log_prob"]
                        # 1-D log-probs are added as-is; anything else is
                        # reduced over dim 1 first.
                        if node_log_prob.ndim != 1:
                            node_log_prob = node_log_prob.sum(dim=1)
                        log_p += node_log_prob

            return log_p