def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """Downstream costs for ``iarange_reuse_model_guide``.

    The same callable is traced as the guide (``include_obs=False``) and,
    replayed against that guide trace, as the model (``include_obs=True``).
    The result of ``_compute_downstream_costs`` is compared with the
    brute-force reference and with a hand-derived expectation for ``c1``.
    """
    dims = dict(dim1=dim1, dim2=dim2)

    traced_guide = poutine.trace(iarange_reuse_model_guide, graph_type="dense")
    guide_trace = traced_guide.get_trace(include_obs=False, **dims)

    replayed = poutine.replay(iarange_reuse_model_guide, trace=guide_trace)
    traced_model = poutine.trace(replayed, graph_type="dense")
    model_trace = traced_model.get_trace(include_obs=True, **dims)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def log_ratio(site_name):
        # log p(site) - log q(site) for a latent site present in both traces
        return model_trace.nodes[site_name]['log_prob'] - guide_trace.nodes[site_name]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # every cost must have the shape of its guide site's log_prob
    for site_name in dc:
        cost = dc[site_name]
        assert guide_trace.nodes[site_name]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site_name])

    expected_c1 = log_ratio('c1')
    expected_c1 += log_ratio('b1').sum()
    expected_c1 += log_ratio('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
def test_compute_downstream_costs_plate_reuse(dim1, dim2):
    """Downstream costs for ``plate_reuse_model_guide``.

    Traces the callable once as guide (``include_obs=False``) and once as
    model replayed against the guide (``include_obs=True``), then checks
    ``_compute_downstream_costs`` against the brute-force implementation
    and against a manually assembled cost for site ``c1``.
    """
    guide_tracer = poutine.trace(plate_reuse_model_guide, graph_type="dense")
    guide_trace = guide_tracer.get_trace(include_obs=False, dim1=dim1, dim2=dim2)

    model_tracer = poutine.trace(
        poutine.replay(plate_reuse_model_guide, trace=guide_trace),
        graph_type="dense",
    )
    model_trace = model_tracer.get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    for key in dc:
        guide_log_prob = guide_trace.nodes[key]['log_prob']
        assert guide_log_prob.size() == dc[key].size()
        assert_equal(dc[key], dc_brute[key])

    model_nodes = model_trace.nodes
    guide_nodes = guide_trace.nodes

    # cost of c1: its own log-ratio plus everything downstream of it
    expected_c1 = model_nodes['c1']['log_prob'] - guide_nodes['c1']['log_prob']
    expected_c1 += (model_nodes['b1']['log_prob'] - guide_nodes['b1']['log_prob']).sum()
    expected_c1 += model_nodes['c2']['log_prob'] - guide_nodes['c2']['log_prob']
    expected_c1 += model_nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
示例#3
0
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    """Downstream costs for ``nested_model_guide2``.

    Compares ``_compute_downstream_costs`` with the brute-force reference
    and with hand-built expectations for sites ``b1``, ``c`` and ``a1``.
    """
    trace_kwargs = dict(dim1=dim1, dim2=dim2)

    guide_trace = poutine.trace(
        nested_model_guide2, graph_type="dense",
    ).get_trace(include_obs=False, **trace_kwargs)
    replayed_model = poutine.replay(nested_model_guide2, trace=guide_trace)
    model_trace = poutine.trace(
        replayed_model, graph_type="dense",
    ).get_trace(include_obs=True, **trace_kwargs)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def model_lp(site_name):
        return model_trace.nodes[site_name]['log_prob']

    def guide_lp(site_name):
        return guide_trace.nodes[site_name]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for site_name in dc:
        assert guide_lp(site_name).size() == dc[site_name].size()
        assert_equal(dc[site_name], dc_brute[site_name])

    expected_b1 = model_lp('b1') - guide_lp('b1')
    expected_b1 += model_lp('obs1')
    assert_equal(expected_b1, dc['b1'])

    # c accumulates the latent log-ratio and the observation term of every b{i}
    expected_c = model_lp('c') - guide_lp('c')
    for idx in range(dim2):
        b_name = 'b{}'.format(idx)
        obs_name = 'obs{}'.format(idx)
        expected_c += model_lp(b_name) - guide_lp(b_name)
        expected_c += model_lp(obs_name)
    assert_equal(expected_c, dc['c'])

    expected_a1 = model_lp('a1') - guide_lp('a1')
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
示例#4
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator yielding one ``(weight, model_trace, guide_trace)`` triple
        per particle.

        For each particle the guide is traced with a dense graph, the model
        is traced while replayed against that guide trace, the pair is
        checked with ``check_model_guide_match`` and both traces are passed
        through ``prune_subsample_sites``.  Every triple carries the weight
        ``1 / num_particles``.

        XXX support for automatically setting args/kwargs to volatile?

        :raises NotImplementedError: if ``self.enum_discrete`` is set
            (see https://github.com/uber/pyro/issues/220).
        """
        for _ in range(self.num_particles):
            if self.enum_discrete:
                raise NotImplementedError(
                    "https://github.com/uber/pyro/issues/220")

            traced_guide = poutine.trace(guide, graph_type="dense")
            guide_trace = traced_guide.get_trace(*args, **kwargs)

            replayed_model = poutine.replay(model, guide_trace)
            traced_model = poutine.trace(replayed_model, graph_type="dense")
            model_trace = traced_model.get_trace(*args, **kwargs)

            # the match check runs on the unpruned traces
            check_model_guide_match(model_trace, guide_trace)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            particle_weight = 1.0 / self.num_particles
            yield particle_weight, model_trace, guide_trace
示例#5
0
    def _update_weights(self, model_trace, guide_trace):
        """
        Fold the log-probabilities of one step into ``self.state._log_weights``.

        Implements ``w_t <- w_{t-1} * p(y_t|z_t) * p(z_t|z_{t-1}) / q(z_t)``
        in log space: every guide sample site contributes ``log p - log q``,
        every observed model sample site contributes ``log p``, and the
        maximum is finally subtracted from the weights.

        :param model_trace: trace of the model step
        :param guide_trace: trace of the guide step
        """
        model_trace = prune_subsample_sites(model_trace)
        guide_trace = prune_subsample_sites(guide_trace)

        model_trace.compute_log_prob()
        guide_trace.compute_log_prob()

        def per_particle(log_prob):
            # flatten everything but the leading particle axis and sum it out
            return log_prob.reshape(self.num_particles, -1).sum(-1)

        for name, guide_site in guide_trace.nodes.items():
            if guide_site["type"] != "sample":
                continue
            log_p = per_particle(model_trace.nodes[name]["log_prob"])
            log_q = per_particle(guide_site["log_prob"])
            self.state._log_weights += log_p - log_q

        for model_site in model_trace.nodes.values():
            if model_site["type"] != "sample":
                continue
            if not model_site["is_observed"]:
                continue
            self.state._log_weights += per_particle(model_site["log_prob"])

        self.state._log_weights -= self.state._log_weights.max()
示例#6
0
    def _update_weights(self, model_trace, guide_trace):
        """
        Fold the log-probabilities of one step into ``self.state._log_weights``.

        Implements ``w_t <- w_{t-1} * p(y_t|z_t) * p(z_t|z_{t-1}) / q(z_t)``
        in log space.  After each site's contribution the weights are checked;
        the maximum is subtracted from the weights at the end.

        :param model_trace: trace of the model step
        :param guide_trace: trace of the guide step
        :raises SMCFailed: if, after adding a site's contribution, the largest
            log-weight is not strictly greater than ``-inf``.
        """
        model_trace = prune_subsample_sites(model_trace)
        guide_trace = prune_subsample_sites(guide_trace)

        model_trace.compute_log_prob()
        guide_trace.compute_log_prob()

        def per_particle(log_prob):
            # flatten everything but the leading particle axis and sum it out
            return log_prob.reshape(self.num_particles, -1).sum(-1)

        def all_infeasible():
            # phrased as "not (max > -inf)" on purpose: mirrors the exact
            # comparison the weights are tested with
            return not (self.state._log_weights.max() > -math.inf)

        for name, guide_site in guide_trace.nodes.items():
            if guide_site["type"] != "sample":
                continue
            log_p = per_particle(model_trace.nodes[name]["log_prob"])
            log_q = per_particle(guide_site["log_prob"])
            self.state._log_weights += log_p - log_q
            if all_infeasible():
                raise SMCFailed(
                    "Failed to find feasible hypothesis after site {}".format(name))

        for model_site in model_trace.nodes.values():
            if model_site["type"] != "sample" or not model_site["is_observed"]:
                continue
            self.state._log_weights += per_particle(model_site["log_prob"])
            if all_infeasible():
                raise SMCFailed(
                    "Failed to find feasible hypothesis after site {}".format(
                        model_site["name"]))

        self.state._log_weights -= self.state._log_weights.max()
示例#7
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator.

        Yields one ``(weight, model_trace, guide_trace)`` triple per particle,
        with ``weight = 1 / num_particles``.  Both traces are built with a
        dense graph and have their subsample sites pruned.

        When validation is enabled, the model/guide pair is checked with
        ``check_model_guide_match`` and a warning is issued if any guide
        sample site is configured for enumeration.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # Each list element is one line of the warning.  The comma
                    # after the first element matters: without it the adjacent
                    # literals are concatenated and become the join separator.
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
示例#8
0
def get_importance_trace(graph_type, max_plate_nesting, model, guide, *args,
                         **kwargs):
    """
    Trace the guide once and the model once, replayed against that guide trace.

    Both traces have their subsample sites pruned; log-probabilities are
    computed on the model trace and score parts on the guide trace.  When
    validation is enabled, the model/guide match and the shape of every
    sample site are checked against ``max_plate_nesting``.

    :param graph_type: graph type forwarded to ``poutine.trace``
    :param max_plate_nesting: bound forwarded to the validation checks
    :param model: the model callable
    :param guide: the guide callable
    :returns: the pair ``(model_trace, guide_trace)``
    """
    traced_guide = poutine.trace(guide, graph_type=graph_type)
    guide_trace = traced_guide.get_trace(*args, **kwargs)

    replayed_model = poutine.replay(model, trace=guide_trace)
    traced_model = poutine.trace(replayed_model, graph_type=graph_type)
    model_trace = traced_model.get_trace(*args, **kwargs)

    if is_validation_enabled():
        check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)

    model_trace.compute_log_prob()
    guide_trace.compute_score_parts()

    if is_validation_enabled():
        # model sites are checked first, then guide sites
        for checked_trace in (model_trace, guide_trace):
            for site in checked_trace.nodes.values():
                if site["type"] != "sample":
                    continue
                check_site_shape(site, max_plate_nesting)

    return model_trace, guide_trace
示例#9
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Runs the guide and runs the model against the guide with
        the result packaged as a trace generator.

        Yields one ``(model_trace, guide_trace)`` pair per particle.  Both
        traces have their subsample sites pruned; score parts are computed on
        the guide trace only (see the TODO below for the model trace).

        When validation is enabled, the model/guide pair is checked, a warning
        is issued if any guide sample site is configured for enumeration, and
        every sample site's shape is checked against
        ``self.max_iarange_nesting``.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # Each list element is one line of the warning.  The comma
                    # after the first element matters: without it the adjacent
                    # literals are concatenated and become the join separator.
                    warnings.warn('\n'.join([
                        'Trace_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            # TODO: model_trace.compute_log_prob() is deliberately not called
            # here; it does not work because there are no decoder parameters.
            guide_trace.compute_score_parts()

            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)
                for site in guide_trace.nodes.values():
                    if site["type"] == "sample":
                        check_site_shape(site, self.max_iarange_nesting)

            yield model_trace, guide_trace
示例#10
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator.

        Yields one ``(weight, model_trace, guide_trace)`` triple per particle,
        with ``weight = 1 / num_particles``.  Both traces are built with a
        dense graph and have their subsample sites pruned.

        When validation is enabled, the model/guide pair is checked with
        ``check_model_guide_match`` and a warning is issued if any guide
        sample site is configured for enumeration.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide, graph_type="dense").get_trace(
                *args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model,
                                                       trace=guide_trace),
                                        graph_type="dense").get_trace(
                                            *args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [
                    name for name, site in guide_trace.nodes.items() if
                    site["type"] == "sample" and site["infer"].get("enumerate")
                ]
                if enumerated_sites:
                    # Each list element is one line of the warning.  The comma
                    # after the first element matters: without it the adjacent
                    # literals are concatenated and become the join separator.
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.'
                    ]))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
示例#11
0
    def _get_traces(self, model, guide, args, kwargs):
        """
        Trace the guide and the model inside a vectorized particle plate.

        The guide is traced first; the model is then traced under
        ``poutine.uncondition`` while replayed against the guide trace, and
        sites whose ``infer`` dict carries ``was_observed`` are re-marked as
        observed.  Subsample sites are pruned from both traces.  If
        ``self.prior_scale > 0``, log-probabilities are computed for the
        non-observed model sites.

        :param model: the model callable
        :param guide: the guide callable
        :param args: positional arguments (a sequence, not star-args)
        :param kwargs: keyword arguments (a dict, not star-kwargs)
        :returns: the pair ``(guide_trace, model_trace)`` -- guide first
        :raises ValueError: with validation enabled, if a guide sample site or
            an observed model sample site has a ``fn`` without ``has_rsample``.
        """
        if self.max_plate_nesting == float("inf"):
            with validation_enabled(
                    False):  # Avoid calling .log_prob() when undefined.
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate("num_particles_vectorized",
                               self.num_particles,
                               dim=-self.max_plate_nesting)

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get(
                    "was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace,
                                    self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError(
                            "EnergyDistance requires fully reparametrized guides"
                        )
            # Iterate with ``site`` as the loop variable so that each model
            # site is validated, rather than re-testing the last guide site.
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if site["is_observed"]:
                        warn_if_nan(site["value"], site["name"])
                        if not getattr(site["fn"], "has_rsample", False):
                            raise ValueError(
                                "EnergyDistance requires reparametrized likelihoods"
                            )

        if self.prior_scale > 0:
            model_trace.compute_log_prob(
                site_filter=lambda name, site: not site["is_observed"])
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        if not site["is_observed"]:
                            check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace
def test_compute_downstream_costs_plate_in_iplate(dim1):
    """Downstream costs for ``nested_model_guide``.

    Checks ``_compute_downstream_costs`` against the brute-force reference
    and against hand-built expectations for sites ``c1``, ``b1``, ``c0``
    and ``b0``.
    """
    traced_guide = poutine.trace(nested_model_guide, graph_type="dense")
    guide_trace = traced_guide.get_trace(include_obs=False, dim1=dim1)

    replayed = poutine.replay(nested_model_guide, trace=guide_trace)
    traced_model = poutine.trace(replayed, graph_type="dense")
    model_trace = traced_model.get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def model_lp(site_name):
        return model_trace.nodes[site_name]['log_prob']

    def log_ratio(site_name):
        return model_lp(site_name) - guide_trace.nodes[site_name]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    # c-sites keep their own shape; b-sites absorb the summed c and obs terms
    expected_c1 = log_ratio('c1')
    expected_c1 += model_lp('obs1')

    expected_b1 = log_ratio('b1')
    expected_b1 += log_ratio('c1').sum()
    expected_b1 += model_lp('obs1').sum()

    expected_c0 = log_ratio('c0')
    expected_c0 += model_lp('obs0')

    expected_b0 = log_ratio('b0')
    expected_b0 += log_ratio('c0').sum()
    expected_b0 += model_lp('obs0').sum()

    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c0, dc['c0'], prec=1.0e-6)
    assert_equal(expected_b0, dc['b0'], prec=1.0e-6)

    for site_name in dc:
        cost = dc[site_name]
        assert guide_trace.nodes[site_name]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site_name])
def test_compute_downstream_costs_duplicates(dim):
    """Downstream costs for the ``diamond_guide`` / ``diamond_model`` pair.

    Checks ``_compute_downstream_costs`` against the brute-force reference
    and against hand-built expectations for sites ``a1``, ``b1`` and ``c1``.
    """
    traced_guide = poutine.trace(diamond_guide, graph_type="dense")
    guide_trace = traced_guide.get_trace(dim=dim)

    replayed_model = poutine.replay(diamond_model, trace=guide_trace)
    traced_model = poutine.trace(replayed_model, graph_type="dense")
    model_trace = traced_model.get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def model_lp(site_name):
        return model_trace.nodes[site_name]['log_prob']

    def guide_lp(site_name):
        return guide_trace.nodes[site_name]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    b_sites = ['b{}'.format(d) for d in range(dim)]

    # a1: own log-ratio, plus model and guide terms of every b site,
    # plus c1's log-ratio and the observation term
    expected_a1 = model_lp('a1') - guide_lp('a1')
    for b_site in b_sites:
        expected_a1 += model_lp(b_site)
        expected_a1 -= guide_lp(b_site)
    expected_a1 += model_lp('c1') - guide_lp('c1')
    expected_a1 += model_lp('obs')

    # b1: only its own guide term, but the model terms of every b site
    expected_b1 = -guide_lp('b1')
    for b_site in b_sites:
        expected_b1 += model_lp(b_site)
    expected_b1 += model_lp('c1') - guide_lp('c1')
    expected_b1 += model_lp('obs')

    expected_c1 = model_lp('c1') - guide_lp('c1')
    for b_site in b_sites:
        expected_c1 += model_lp(b_site)
    expected_c1 += model_lp('obs')

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)

    for site_name in dc:
        assert guide_lp(site_name).size() == dc[site_name].size()
        assert_equal(dc[site_name], dc_brute[site_name])
示例#14
0
def assert_ok(model, guide=None, max_plate_nesting=None, **kwargs):
    """
    Assert that enumeration runs and that both backends agree.

    While both work queues are non-empty, the queued ``guide`` is traced under
    ``handlers.enum`` and ``model`` is traced replayed against it, once in the
    ``"pyro"`` backend and once in the ``"contrib.funsor"`` backend.  After
    each round the dimension stack is asserted to be clean and the two pruned
    model traces are compared with ``_check_traces``.

    :param model: the model callable
    :param guide: the guide callable; a no-op guide is used when ``None``
    :param max_plate_nesting: used to derive ``first_available_dim``
    :param kwargs: forwarded to every ``get_trace`` call
    """
    with pyro_backend("pyro"):
        pyro.clear_param_store()

    if guide is None:
        guide = lambda **kwargs: None  # noqa: E731

    q_pyro = LifoQueue()
    q_funsor = LifoQueue()
    for work_queue in (q_pyro, q_funsor):
        work_queue.put(Trace())

    def trace_in_backend(backend_name, work_queue):
        # one round: queued guide, then the model replayed against its trace
        with pyro_backend(backend_name):
            with handlers.enum(first_available_dim=-max_plate_nesting - 1):
                queued_guide = handlers.queue(
                    guide,
                    work_queue,
                    escape_fn=iter_discrete_escape,
                    extend_fn=iter_discrete_extend,
                )
                guide_tr = handlers.trace(queued_guide).get_trace(**kwargs)
                replayed_model = handlers.replay(model, trace=guide_tr)
                return handlers.trace(replayed_model).get_trace(**kwargs)

    while not (q_pyro.empty() or q_funsor.empty()):
        tr_pyro = trace_in_backend("pyro", q_pyro)
        tr_funsor = trace_in_backend("contrib.funsor", q_funsor)

        # make sure all dimensions were cleaned up
        global_frame = _DIM_STACK.global_frame
        assert _DIM_STACK.local_frame is global_frame
        assert not global_frame.name_to_dim and not global_frame.dim_to_name
        assert _DIM_STACK.outermost is None

        tr_pyro = prune_subsample_sites(tr_pyro.copy())
        tr_funsor = prune_subsample_sites(tr_funsor.copy())
        _check_traces(tr_pyro, tr_funsor)
示例#15
0
    def _get_matched_trace(self, model_trace, *args, **kwargs):
        """
        Run the guide replayed against ``model_trace`` and return its trace.

        Model sites whose ``infer`` dict contains ``"was_observed"`` are
        marked as observed (in place, on ``model_trace``) and their values are
        handed to the guide through the ``observations`` keyword argument.
        The resulting guide trace is checked against the model trace and has
        its subsample sites pruned.

        :param model_trace: a trace from the model
        :type model_trace: pyro.poutine.trace_struct.Trace
        :returns: guide trace with sampled values matched to model_trace
        :rtype: pyro.poutine.trace_struct.Trace

        `args` and `kwargs` are passed to the guide.
        """
        observations = {}
        kwargs["observations"] = observations

        candidate_names = itertools.chain(model_trace.stochastic_nodes,
                                          model_trace.observation_nodes)
        for site_name in candidate_names:
            site = model_trace.nodes[site_name]
            if "was_observed" not in site["infer"]:
                continue
            site["is_observed"] = True
            observations[site_name] = site["value"]

        replayed_guide = poutine.replay(self.guide, model_trace)
        guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)

        check_model_guide_match(model_trace, guide_trace)
        return prune_subsample_sites(guide_trace)
 def forward(self, *args, **kwargs):
     """
     Draw one joint sample and return the site values as a tuple.

     The guide is traced, the model is traced replayed against it, and the
     values of the stochastic nodes that survive ``prune_subsample_sites``
     are returned ordered by site name.
     """
     traced_guide = poutine.trace(self.guide)
     guide_trace = traced_guide.get_trace(*args, **kwargs)
     replayed_model = poutine.replay(self.model, guide_trace)
     model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)

     # names come from the pruned trace, values from the unpruned one
     kept_sites = prune_subsample_sites(model_trace).stochastic_nodes
     samples = {name: model_trace.nodes[name]['value'] for name in kept_sites}
     return tuple(samples[name] for name in sorted(samples))
示例#17
0
    def _setup_prototype(self, *args, **kwargs):
        """
        Run the model once and record its discrete sites and plates.

        Populates ``self.prototype_trace``, ``self._discrete_sites``,
        ``self._cond_indep_stacks`` and ``self._plates``.

        :raises NotImplementedError: if a stochastic site is not configured
            for parallel enumeration, has an unsupported distribution type,
            or sits inside a non-vectorized ``cond_indep_stack`` frame.
        """
        # run the model so we can inspect its structure
        enumerated_model = config_enumerate(self.model)
        run_blocked = poutine.block(poutine.trace(enumerated_model).get_trace)
        self.prototype_trace = run_blocked(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._plates = {}
        supported_dists = (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical)

        for name, site in self.prototype_trace.iter_stochastic_nodes():
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                          'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist not in supported_dists:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            stack = site["cond_indep_stack"]
            self._cond_indep_stacks[name] = stack
            for frame in stack:
                if not frame.vectorized:
                    raise NotImplementedError("AutoDiscreteParallel does not support sequential pyro.plate")
                self._plates[frame.name] = frame
示例#18
0
文件: __init__.py 项目: lewisKit/pyro
    def _setup_prototype(self, *args, **kwargs):
        """
        Run the model once and record its discrete sites and iaranges.

        Populates ``self.prototype_trace``, ``self._discrete_sites``,
        ``self._cond_indep_stacks`` and ``self._iaranges``.  Non-sample and
        observed sites are ignored.

        :raises NotImplementedError: if a latent sample site is not configured
            for parallel enumeration, has an unsupported distribution type,
            or sits inside a non-vectorized ``cond_indep_stack`` frame.
        """
        # run the model so we can inspect its structure
        enumerated_model = config_enumerate(self.model, default="parallel")
        run_blocked = poutine.block(poutine.trace(enumerated_model).get_trace)
        self.prototype_trace = run_blocked(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._iaranges = {}
        supported_dists = (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical)

        for name, site in self.prototype_trace.nodes.items():
            # only latent sample sites are of interest
            if site["type"] != "sample" or site["is_observed"]:
                continue
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                          'configured for parallel enumeration'.format(name))

            # collect discrete sample sites
            fn = site["fn"]
            Dist = type(fn)
            if Dist not in supported_dists:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
            self._discrete_sites.append((site, Dist, params))

            # collect independence contexts
            stack = site["cond_indep_stack"]
            self._cond_indep_stacks[name] = stack
            for frame in stack:
                if not frame.vectorized:
                    raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
                self._iaranges[frame.name] = frame
示例#19
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator yielding ``(model_trace, guide_trace)`` pairs.

        The guide is wrapped in ``poutine.enum``; for every particle, each
        guide trace produced by ``iter_discrete_traces`` is paired with a
        model trace replayed against it.  Subsample sites are pruned,
        log-probabilities are computed on the model trace and score parts on
        the guide trace.

        When validation is enabled, model/guide consistency, the enumeration
        requirements and all sample-site shapes are checked, and a warning is
        issued if ``self.strict_enumeration_warning`` is set and no guide
        site is configured for enumeration.
        """
        # enable parallel enumeration
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                replayed_model = poutine.replay(model, trace=guide_trace)
                traced_model = poutine.trace(replayed_model, graph_type="flat")
                model_trace = traced_model.get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(
                        model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()

                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)

                    # shape-check every guide site while noting whether any
                    # of them requests enumeration
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] != "sample":
                            continue
                        check_site_shape(site, self.max_iarange_nesting)
                        if site["infer"].get("enumerate"):
                            any_enumerated = True

                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn(
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )

                yield model_trace, guide_trace
def test_compute_downstream_costs_duplicates(dim):
    """
    Compare ``_compute_downstream_costs`` on the ``diamond_model`` /
    ``diamond_guide`` pair against the brute-force implementation and against
    hand-assembled expected costs for the sites ``a1``, ``b1`` and ``c1``.
    """
    traced_guide = poutine.trace(diamond_guide, graph_type="dense")
    guide_trace = traced_guide.get_trace(dim=dim)
    replayed_model = poutine.replay(diamond_model, trace=guide_trace)
    model_trace = poutine.trace(replayed_model, graph_type="dense").get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def model_lp(name):
        # log-probability recorded at a model site
        return model_trace.nodes[name]['log_prob']

    def guide_lp(name):
        # log-probability recorded at a guide site
        return guide_trace.nodes[name]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    b_sites = ['b{}'.format(d) for d in range(dim)]

    # The accumulations below are in-place on a freshly created tensor, so the
    # shape of each expected cost is fixed by its first term.
    expected_a1 = model_lp('a1') - guide_lp('a1')
    for b_site in b_sites:
        expected_a1 += model_lp(b_site)
        expected_a1 -= guide_lp(b_site)
    expected_a1 += model_lp('c1') - guide_lp('c1')
    expected_a1 += model_lp('obs')

    expected_b1 = -guide_lp('b1')
    for b_site in b_sites:
        expected_b1 += model_lp(b_site)
    expected_b1 += model_lp('c1') - guide_lp('c1')
    expected_b1 += model_lp('obs')

    expected_c1 = model_lp('c1') - guide_lp('c1')
    for b_site in b_sites:
        expected_c1 += model_lp(b_site)
    expected_c1 += model_lp('obs')

    expectations = (('a1', expected_a1), ('b1', expected_b1), ('c1', expected_c1))
    for site_name, expected in expectations:
        assert_equal(expected, dc[site_name], prec=1.0e-6)

    for node in dc:
        cost = dc[node]
        assert guide_lp(node).size() == cost.size()
        assert_equal(cost, dc_brute[node])
示例#21
0
def _predictive(model, posterior_samples, num_samples, return_sites=(),
                return_trace=False, parallel=False, model_args=(), model_kwargs=None):
    """
    Run ``model`` conditioned on ``posterior_samples`` and collect values at
    the selected sites.

    :param model: callable traced with ``poutine.trace``.
    :param dict posterior_samples: maps site name to a tensor; each tensor is
        reshaped to ``(num_samples,) + (1,) * k + sample.shape[1:]`` where ``k``
        pads up to the guessed ``max_plate_nesting``.
    :param int num_samples: size of the added ``_num_predictive_samples`` plate.
    :param return_sites: non-empty collection -> only those sites; ``None`` ->
        every stochastic/observed site; empty (default) -> every such site that
        is not a key of ``posterior_samples``.
    :param bool return_trace: if true, return the vectorized conditioned trace
        instead of a dict of values.
    :param bool parallel: if false, delegate to ``_predictive_sequential``.
    :param tuple model_args: positional arguments forwarded to ``model``.
    :param dict model_kwargs: keyword arguments forwarded to ``model``
        (``None`` is treated as an empty dict).
    :return: the trace (``return_trace=True``), the result of
        ``_predictive_sequential`` (``parallel=False``), or a dict mapping site
        name to its value expanded/reshaped to the recorded site shape.
    """
    # A ``{}`` default would be one dict object shared by every call; use a
    # None sentinel and create a fresh dict per call instead.
    if model_kwargs is None:
        model_kwargs = {}

    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    vectorize = pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1)
    model_trace = prune_subsample_sites(poutine.trace(model).get_trace(*model_args, **model_kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        # keep the leading sample dim, pad with singleton dims up to max_plate_nesting
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples,) + (1,) * (max_plate_nesting - len(sample_shape)) + sample_shape)
        reshaped_samples[name] = sample

    if return_trace:
        trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
        site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
        # non-empty return-sites
        if return_sites:
            if site in return_sites:
                return_site_shapes[site] = site_shape
        # special case (for guides): include all sites
        elif return_sites is None:
            return_site_shapes[site] = site_shape
        # default case: return sites = ()
        # include all sites not in posterior samples
        elif site not in posterior_samples:
            return_site_shapes[site] = site_shape

    # handle _RETURN site
    if return_sites is not None and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        # a non-tensor return value has no shape to expand to; mark it with None
        shape = (num_samples,) + value.shape if torch.is_tensor(value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
                                      return_site_shapes, return_trace=False)

    trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        # fewer elements than the target shape -> broadcast; otherwise reshape
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions
def test_compute_downstream_costs_iarange_in_irange(dim1):
    """
    Compare ``_compute_downstream_costs`` on ``nested_model_guide`` against
    the brute-force implementation and against hand-assembled expected costs
    for the sites ``c1``, ``b1``, ``c0`` and ``b0``.
    """
    traced_guide = poutine.trace(nested_model_guide, graph_type="dense")
    guide_trace = traced_guide.get_trace(include_obs=False, dim1=dim1)
    replayed_model = poutine.replay(nested_model_guide, trace=guide_trace)
    model_trace = poutine.trace(replayed_model,
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    def model_lp(name):
        # log-probability recorded at a model site
        return model_trace.nodes[name]['log_prob']

    def elbo_term(name):
        # model log-prob minus guide log-prob at the same site
        return model_trace.nodes[name]['log_prob'] - guide_trace.nodes[name]['log_prob']

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    # Build the expectations for index 1 first, then index 0, so the
    # assertions run in the order c1, b1, c0, b0.
    expectations = []
    for index in (1, 0):
        b_site = 'b{}'.format(index)
        c_site = 'c{}'.format(index)
        obs_site = 'obs{}'.format(index)

        expected_c = elbo_term(c_site)
        expected_c += model_lp(obs_site)

        expected_b = elbo_term(b_site)
        expected_b += elbo_term(c_site).sum()
        expected_b += model_lp(obs_site).sum()

        expectations.append((c_site, expected_c))
        expectations.append((b_site, expected_b))

    for site_name, expected in expectations:
        assert_equal(expected, dc[site_name], prec=1.0e-6)

    for node in dc:
        cost = dc[node]
        assert guide_trace.nodes[node]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[node])
示例#23
0
    def _guess_max_plate_nesting(self, model, guide, args, kwargs):
        """
        Run the (model, guide) pair once without enumeration and set
        ``self.max_plate_nesting`` from the vectorized plate dims observed on
        the sample sites. This optimistically assumes the model structure is
        static.
        """
        # Ignore validation to allow model-enumerated sites absent from the guide.
        with poutine.block():
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            replayed_model = poutine.replay(model, trace=guide_trace)
            model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)

        # Sample sites of the model first, then of the guide.
        sample_sites = []
        for trace in (model_trace, guide_trace):
            for node in trace.nodes.values():
                if node["type"] == "sample":
                    sample_sites.append(node)

        # Shapes are validated here because the constraints get weaker once
        # max_plate_nesting goes from float('inf') to a finite value: these
        # traces are not enumerated, but later dims to the left of
        # max_plate_nesting must be allowed to broadcast.
        if is_validation_enabled():
            guide_trace.compute_log_prob()
            model_trace.compute_log_prob()
            for node in sample_sites:
                check_site_shape(node, max_plate_nesting=float("inf"))

        vectorized_dims = []
        for node in sample_sites:
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    vectorized_dims.append(frame.dim)

        if vectorized_dims:
            self.max_plate_nesting = -min(vectorized_dims)
        else:
            self.max_plate_nesting = 0
        # one extra dim when particles are vectorized
        if self.vectorize_particles and self.num_particles > 1:
            self.max_plate_nesting += 1
        logging.info("Guessed max_plate_nesting = {}".format(self.max_plate_nesting))
示例#24
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator over ``(weight, model_trace, guide_trace, log_r)`` tuples.

        For each particle the guide is run and the model is replayed against
        it. With ``self.enum_discrete`` set, one tuple is yielded per trace
        produced by ``iter_discrete_traces``.

        XXX support for automatically settings args/kwargs to volatile?
        """
        for _ in range(self.num_particles):
            if not self.enum_discrete:
                guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
                replayed_model = poutine.replay(model, guide_trace)
                model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)

                check_model_guide_match(model_trace, guide_trace)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)

                log_r = model_trace.log_pdf() - guide_trace.log_pdf()
                weight = 1.0 / self.num_particles
                yield weight, model_trace, guide_trace, log_r
                continue

            # This iterates over a bag of traces, for each particle.
            discrete_traces = iter_discrete_traces("flat", guide, *args, **kwargs)
            for scale, guide_trace in discrete_traces:
                replayed_model = poutine.replay(model, guide_trace)
                model_trace = poutine.trace(
                    replayed_model, graph_type="flat").get_trace(*args, **kwargs)

                check_model_guide_match(model_trace, guide_trace)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                check_enum_discrete_can_run(model_trace, guide_trace)

                log_r = model_trace.batch_log_pdf() - guide_trace.batch_log_pdf()
                weight = scale / self.num_particles
                yield weight, model_trace, guide_trace, log_r
示例#25
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator over ``(model_trace, guide_trace)`` pairs: the guide is
        wrapped for enumeration, and the model is replayed against every guide
        trace produced by ``iter_discrete_traces`` for each particle.
        """
        # enable parallel enumeration
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for _ in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                replayed_model = poutine.replay(model, trace=guide_trace)
                model_trace = poutine.trace(
                    replayed_model, graph_type="flat").get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for node in model_trace.nodes.values():
                        if node["type"] != "sample":
                            continue
                        check_site_shape(node, self.max_iarange_nesting)

                    found_enumerated = False
                    for node in guide_trace.nodes.values():
                        if node["type"] != "sample":
                            continue
                        check_site_shape(node, self.max_iarange_nesting)
                        if node["infer"].get("enumerate"):
                            found_enumerated = True

                    if self.strict_enumeration_warning and not found_enumerated:
                        message = (
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )
                        warnings.warn(message)

                yield model_trace, guide_trace
示例#26
0
文件: discrete.py 项目: ucals/pyro
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    """
    Trace ``model`` under an ``EnumMessenger``, prepare the resulting trace
    (prune subsample sites, compute log-probs, pack tensors) and hand it to
    ``_sample_posterior_from_trace``.

    For internal use by infer_discrete.
    """
    # Build the enumerated trace with outer handlers blocked.
    with poutine.block():
        with EnumMessenger(first_available_dim):
            traced_model = poutine.trace(model)
            enum_trace = traced_model.get_trace(*args, **kwargs)

    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()

    return _sample_posterior_from_trace(
        model, enum_trace, temperature, *args, **kwargs)
示例#27
0
文件: easyguide.py 项目: www3cam/pyro
    def _setup_prototype(self, *args, **kwargs):
        """
        Trace the model once (wrapped in ``InitMessenger(self.init)``), store
        the pruned trace as ``self.prototype_trace`` and record every plate
        frame of its stochastic sites in ``self.frames``.

        :raises NotImplementedError: if a frame is not vectorized.
        """
        # run the model so we can inspect its structure
        initialized_model = InitMessenger(self.init)(self.model)
        run_blocked = poutine.block(poutine.trace(initialized_model).get_trace)
        self.prototype_trace = run_blocked(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)

        for _, site in self.prototype_trace.iter_stochastic_nodes():
            for frame in site["cond_indep_stack"]:
                if frame.vectorized:
                    self.frames[frame.name] = frame
                else:
                    raise NotImplementedError(
                        "EasyGuide does not support sequential pyro.plate")
示例#28
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generator over ``(weight, model_trace, guide_trace)`` tuples, one per
        particle: the guide is traced with a dense graph and the model is
        replayed against it, also with a dense graph.

        XXX support for automatically settings args/kwargs to volatile?

        :raises NotImplementedError: if ``self.enum_discrete`` is set.
        """
        for _ in range(self.num_particles):
            if self.enum_discrete:
                raise NotImplementedError("https://github.com/uber/pyro/issues/220")

            traced_guide = poutine.trace(guide, graph_type="dense")
            guide_trace = traced_guide.get_trace(*args, **kwargs)
            traced_model = poutine.trace(poutine.replay(model, guide_trace),
                                         graph_type="dense")
            model_trace = traced_model.get_trace(*args, **kwargs)

            check_model_guide_match(model_trace, guide_trace)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            # every particle carries the same weight
            yield 1.0 / self.num_particles, model_trace, guide_trace
示例#29
0
    def _get_matched_trace(model_trace, guide, args, kwargs):
        """
        Replay ``guide`` against ``model_trace`` and return the pruned guide
        trace.

        Side effects: every model site whose ``infer`` dict contains
        ``"was_observed"`` is flagged ``is_observed = True`` in ``model_trace``,
        and its value is stored in ``kwargs["observations"]`` (the passed-in
        ``kwargs`` dict is modified in place).
        """
        observations = {}
        kwargs["observations"] = observations
        site_names = model_trace.stochastic_nodes + model_trace.observation_nodes
        for name in site_names:
            site = model_trace.nodes[name]
            if "was_observed" not in site["infer"]:
                continue
            site["is_observed"] = True
            observations[name] = site["value"]

        replayed_guide = poutine.replay(guide, model_trace)
        guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)
        check_model_guide_match(model_trace, guide_trace)

        return prune_subsample_sites(guide_trace)
示例#30
0
    def _setup_prototype(self, *args, **kwargs):
        """
        Trace ``self.model`` once, store the pruned trace as
        ``self.prototype_trace`` and collect the vectorized plate frames of
        its stochastic sites in ``self._plates`` (keyed by frame name).

        :raises NotImplementedError: if a frame is not vectorized.
        """
        # run the model so we can inspect its structure
        run_blocked = poutine.block(poutine.trace(self.model).get_trace)
        self.prototype_trace = run_blocked(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            # self.master is called to obtain the object that checks the prototype
            self.master()._check_prototype(self.prototype_trace)

        self._plates = {}
        for _, site in self.prototype_trace.iter_stochastic_nodes():
            for frame in site["cond_indep_stack"]:
                if not frame.vectorized:
                    raise NotImplementedError("AutoGuideList does not support sequential pyro.plate")
                self._plates[frame.name] = frame
示例#31
0
    def _setup_prototype(self, *args, **kwargs):
        """
        Trace ``self.model`` once, store the pruned trace as
        ``self.prototype_trace`` and collect the vectorized frames of its
        unobserved sample sites in ``self._iaranges`` (keyed by frame name).

        :raises NotImplementedError: if a frame is not vectorized.
        """
        # run the model so we can inspect its structure
        run_blocked = poutine.block(poutine.trace(self.model).get_trace)
        self.prototype_trace = run_blocked(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            # self.master is called to obtain the object that checks the prototype
            self.master()._check_prototype(self.prototype_trace)

        self._iaranges = {}
        for _, site in self.prototype_trace.nodes.items():
            is_latent_sample = site["type"] == "sample" and not site["is_observed"]
            if not is_latent_sample:
                continue
            for frame in site["cond_indep_stack"]:
                if not frame.vectorized:
                    raise NotImplementedError("AutoGuideList does not support pyro.irange")
                self._iaranges[frame.name] = frame
示例#32
0
def test_predictive(auto_class):
    """
    Build a small Bayesian linear regression, wrap it in ``auto_class`` as the
    guide, and check that ``Predictive`` returns every stochastic site of the
    model and that the module survives ``torch.jit`` tracing plus a
    save/load round trip.
    """
    N, D = 3, 2

    class RandomLinear(nn.Linear, PyroModule):
        # nn.Linear whose weight and bias are replaced by PyroSample priors
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features)
            weight_prior = dist.Normal(0., 1.).expand([out_features, in_features])
            self.weight = PyroSample(weight_prior.to_event(2))
            bias_prior = dist.Normal(0., 10.).expand([out_features])
            self.bias = PyroSample(bias_prior.to_event(1))

    class LinearRegression(PyroModule):
        def __init__(self):
            super().__init__()
            self.linear = RandomLinear(D, 1)

        def forward(self, x, y=None):
            mean = self.linear(x).squeeze(-1)
            sigma = pyro.sample("sigma", dist.LogNormal(0., 1.))
            with pyro.plate('plate', N):
                obs = pyro.sample('obs', dist.Normal(mean, sigma), obs=y)
            return obs

    x = torch.randn(N, D)
    y = torch.randn(N)
    model = LinearRegression()
    guide = auto_class(model)
    # XXX: Record `y` as observed in the prototype trace
    # Is there a better pattern to follow?
    guide(x, y=y)

    # Test predictive module
    model_trace = poutine.trace(model).get_trace(x, y=None)
    predictive = Predictive(model, guide=guide, num_samples=10)
    pyro.set_rng_seed(0)
    samples = predictive(x)
    latent_sites = prune_subsample_sites(model_trace).stochastic_nodes
    for site_name in latent_sites:
        assert site_name in samples

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        traced_predictive = torch.jit.trace_module(predictive, {"call": (x, )})

    # Round-trip the traced module through an in-memory buffer.
    buffer = io.BytesIO()
    torch.jit.save(traced_predictive, buffer)
    buffer.seek(0)
    predictive_deser = torch.jit.load(buffer)
    pyro.set_rng_seed(0)
    samples_deser = predictive_deser.call(x)
    # Note that the site values are different in the serialized guide
    assert len(samples) == len(samples_deser)
示例#33
0
    def forward(self, *args, **kwargs):
        """
        Draw ``self.num_samples`` samples: for each one, run ``self.guide``,
        replay ``self.model`` against it, and record the value of every
        stochastic site of the pruned model trace.

        :return: dict mapping site name to the ``torch.stack`` of its values.
        """
        collected = {}
        for step in range(self.num_samples):
            # progress message every 50 iterations
            if step % 50 == 0:
                print("done with {}".format(step))
            guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
            replayed_model = poutine.replay(self.model, guide_trace)
            model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
            for name in prune_subsample_sites(model_trace).stochastic_nodes:
                collected.setdefault(name, []).append(model_trace.nodes[name]["value"])

        return {name: torch.stack(values) for name, values in collected.items()}
示例#34
0
    def _get_matched_cross_trace(self, model_x_trace, model_z_trace, *args, **kwargs):
        """
        Replay ``self.guide`` against ``model_z_trace`` and return the pruned
        guide trace.

        Sites of ``model_x_trace`` whose ``infer`` dict contains
        ``"was_observed"`` are flagged ``is_observed = True`` in both model
        traces and passed to the guide via ``kwargs["observations"]``; the
        values of all other sites are passed via ``kwargs["truth"]``.
        """
        observations = {}
        truth = {}
        kwargs["observations"] = observations
        kwargs["truth"] = truth

        site_names = itertools.chain(model_x_trace.stochastic_nodes,
                                     model_x_trace.observation_nodes)
        for name in site_names:
            x_site = model_x_trace.nodes[name]
            if "was_observed" not in x_site["infer"]:
                truth[name] = x_site["value"]
                continue
            x_site["is_observed"] = True
            model_z_trace.nodes[name]["is_observed"] = True
            observations[name] = x_site["value"]

        replayed_guide = poutine.replay(self.guide, model_z_trace)
        guide_trace = poutine.trace(replayed_guide).get_trace(*args, **kwargs)

        # the guide must match both model traces
        for reference_trace in (model_x_trace, model_z_trace):
            check_model_guide_match(reference_trace, guide_trace)

        return prune_subsample_sites(guide_trace)
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23,
                                                       include_triple, include_z1):
    """
    Compare ``_compute_downstream_costs`` on ``big_model_guide`` against the
    brute-force implementation, against hand-assembled expected costs, and
    (for one flag combination) against the expected downstream node sets.
    """
    structure = dict(include_inner_1=include_inner_1, include_single=include_single,
                     flip_c23=flip_c23, include_triple=include_triple,
                     include_z1=include_z1)
    traced_guide = poutine.trace(big_model_guide, graph_type="dense")
    guide_trace = traced_guide.get_trace(include_obs=False, **structure)
    replayed_model = poutine.replay(big_model_guide, trace=guide_trace)
    model_trace = poutine.trace(replayed_model,
                                graph_type="dense").get_trace(include_obs=True, **structure)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    def model_lp(name):
        # log-probability recorded at a model site
        return model_trace.nodes[name]['log_prob']

    def guide_lp(name):
        # log-probability recorded at a guide site
        return guide_trace.nodes[name]['log_prob']

    def elbo_term(name):
        # model log-prob minus guide log-prob at the same site
        return model_lp(name) - guide_lp(name)

    dc, dc_nodes = _compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_nodes_full_model = {
        'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'},
        'd2': {'obs', 'd2'},
        'd1': {'obs', 'd1', 'd2'},
        'c3': {'d2', 'obs', 'd1', 'c3'},
        'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
        'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
        'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
        'c2': {'obs', 'd1', 'c3', 'd2', 'c2'},
    }
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert dc_nodes == expected_nodes_full_model

    inner_sites = ('c1', 'c2', 'c3')

    # In-place accumulation: the shape of each expected cost is fixed by its
    # first term.
    expected_b1 = elbo_term('b1')
    for name in ('d2', 'd1'):
        expected_b1 += elbo_term(name).sum(0)
    expected_b1 += model_lp('obs').sum(0, keepdim=False)
    if include_inner_1:
        for name in inner_sites:
            expected_b1 += elbo_term(name).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        expected_b0 = elbo_term('b0')
        for name in ('b1', 'd2', 'd1'):
            expected_b0 += elbo_term(name).sum()
        expected_b0 += model_lp('obs').sum()
        if include_inner_1:
            for name in inner_sites:
                expected_b0 += elbo_term(name).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5,)

    if include_inner_1:
        def inner_cost(name):
            # own term plus the d1, d2 and obs contributions summed over dim 0
            total = elbo_term(name)
            for downstream in ('d1', 'd2'):
                total += elbo_term(downstream).sum(0)
            total += model_lp('obs').sum(0)
            return total

        expected_c3 = inner_cost('c3')
        expected_c2 = inner_cost('c2')
        expected_c1 = elbo_term('c1')

        if flip_c23:
            expected_c3 += elbo_term('c2')
            expected_c2 += model_lp('c3')
        else:
            expected_c2 += elbo_term('c3')
            expected_c2 += elbo_term('c2')
        expected_c1 += expected_c3

        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = elbo_term('d1')
    expected_d1 += elbo_term('d2')
    expected_d1 += model_lp('obs')

    expected_d2 = elbo_term('d2')
    expected_d2 += model_lp('obs')

    if include_triple:
        expected_z0 = dc['a1'] + model_lp('z0') - guide_lp('z0')
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)

    assert dc['b1'].size() == (2,)
    assert dc['d2'].size() == (4, 2)

    for node in dc:
        cost = dc[node]
        assert guide_lp(node).size() == cost.size()
        assert_equal(cost, dc_brute[node])
示例#36
0
def _predictive(model,
                posterior_samples,
                num_samples,
                return_sites=None,
                return_trace=False,
                parallel=False,
                model_args=(),
                model_kwargs=None):
    """
    Run ``model`` conditioned on ``posterior_samples`` and collect values at
    the selected sites.

    :param model: callable traced with ``poutine.trace``.
    :param dict posterior_samples: maps site name to a tensor; each tensor is
        reshaped to ``(num_samples,) + (1,) * k + sample.shape[1:]`` where ``k``
        pads up to the guessed ``max_plate_nesting``.
    :param int num_samples: size of the added ``_num_predictive_samples`` plate.
    :param return_sites: a ``list``/``tuple``/``set`` selects exactly those
        sites; ``None`` (default) selects every stochastic/observed site that
        is not in ``posterior_samples``; any other non-``None`` value selects
        every such site.
    :param bool return_trace: if true, return the vectorized conditioned trace
        instead of a dict of values.
    :param bool parallel: if false, delegate to ``_predictive_sequential``.
    :param tuple model_args: positional arguments forwarded to ``model``.
    :param dict model_kwargs: keyword arguments forwarded to ``model``
        (``None`` is treated as an empty dict).
    :return: the trace (``return_trace=True``), the result of
        ``_predictive_sequential`` (``parallel=False``), or a dict mapping site
        name to its value expanded/reshaped to the recorded site shape.
    """
    # A ``{}`` default would be one dict object shared by every call; use a
    # None sentinel and create a fresh dict per call instead.
    if model_kwargs is None:
        model_kwargs = {}

    max_plate_nesting = _guess_max_plate_nesting(model, model_args,
                                                 model_kwargs)
    model_trace = prune_subsample_sites(
        poutine.trace(model).get_trace(*model_args, **model_kwargs))
    reshaped_samples = {}

    for name, sample in posterior_samples.items():
        # keep the leading sample dim, pad with singleton dims up to max_plate_nesting
        sample_shape = sample.shape[1:]
        sample = sample.reshape((num_samples, ) + (1, ) *
                                (max_plate_nesting - len(sample_shape)) +
                                sample_shape)
        reshaped_samples[name] = sample

    def _vectorized_fn(fn):
        """
        Wraps a callable inside an outermost :class:`~pyro.plate` to parallelize
        sampling from the posterior predictive.

        :param fn: arbitrary callable containing Pyro primitives.
        :return: wrapped callable.
        """
        def wrapped_fn(*args, **kwargs):
            with pyro.plate("_num_predictive_samples",
                            num_samples,
                            dim=-max_plate_nesting - 1):
                return fn(*args, **kwargs)

        return wrapped_fn

    if return_trace:
        trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\
            .get_trace(*model_args, **model_kwargs)
        return trace

    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site_shape = (num_samples, ) + model_trace.nodes[site]['value'].shape
        if isinstance(return_sites, (list, tuple, set)):
            # explicit collection: keep only the requested sites
            if site in return_sites:
                return_site_shapes[site] = site_shape
        else:
            # no collection given: keep the site unless return_sites is None
            # and the site was supplied in the posterior samples
            if (return_sites is not None) or (site not in reshaped_samples):
                return_site_shapes[site] = site_shape

    # handle _RETURN site
    if isinstance(return_sites,
                  (list, tuple, set)) and '_RETURN' in return_sites:
        value = model_trace.nodes['_RETURN']['value']
        # a non-tensor return value has no shape to expand to; mark it with None
        shape = (num_samples, ) + value.shape if torch.is_tensor(
            value) else None
        return_site_shapes['_RETURN'] = shape

    if not parallel:
        return _predictive_sequential(model,
                                      posterior_samples,
                                      model_args,
                                      model_kwargs,
                                      num_samples,
                                      return_site_shapes.keys(),
                                      return_trace=False)

    trace = poutine.trace(poutine.condition(_vectorized_fn(model), reshaped_samples))\
        .get_trace(*model_args, **model_kwargs)
    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        if site == '_RETURN' and shape is None:
            predictions[site] = value
            continue
        # fewer elements than the target shape -> broadcast; otherwise reshape
        if value.numel() < reduce((lambda x, y: x * y), shape):
            predictions[site] = value.expand(shape)
        else:
            predictions[site] = value.reshape(shape)

    return predictions
示例#37
0
def predictive(model, posterior_samples, *args, **kwargs):
    """
    .. warning::
        Deprecated; scheduled for removal in a future release.
        Prefer the :class:`~pyro.infer.predictive.Predictive` class.

    Condition ``model`` on the values in ``posterior_samples``, run it forward
    and collect the values recorded at its sample sites. Unless ``return_sites``
    is given, only sites that are absent from ``posterior_samples`` are collected.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: maps site names to samples whose leading
        dimension indexes the individual posterior draws.
    :param args: positional arguments forwarded to ``model``.
    :param kwargs: keyword arguments forwarded to ``model`` once the options
        listed below have been popped off.

    :Keyword Arguments:
        * **num_samples** (``int``) - how many predictive draws to produce. Only
          consulted when ``posterior_samples`` is empty; otherwise the leading
          dimension of the posterior samples wins (a ``UserWarning`` is emitted
          on disagreement).
        * **return_sites** (``list``) - names of the sites to return. When falsy,
          every stochastic/observed site not found in ``posterior_samples`` is
          returned.
        * **return_trace** (``bool``) - return the trace rather than a dict of
          values. In parallel mode this is one trace vectorized over
          ``num_samples``.
        * **parallel** (``bool``) - wrap the model in an outermost ``plate`` and
          draw all samples in a single run instead of delegating to
          ``_predictive_sequential``. The model must have all batch dims
          annotated via :class:`~pyro.plate`. Defaults to ``False``.

    :return: dict mapping site names to predictive samples, or a trace when
        ``return_trace=True``.
    :raises ValueError: if ``num_samples`` is not given and cannot be inferred
        because ``posterior_samples`` is empty.
    """
    deprecation_msg = ('The `mcmc.predictive` function is deprecated and will be removed in '
                       'a future release. Use the `pyro.infer.Predictive` class instead.')
    warnings.warn(deprecation_msg, FutureWarning)

    # Strip the options out of kwargs so only genuine model kwargs remain.
    num_samples = kwargs.pop('num_samples', None)
    return_sites = kwargs.pop('return_sites', None)
    return_trace = kwargs.pop('return_trace', False)
    parallel = kwargs.pop('parallel', False)

    max_plate_nesting = _guess_max_plate_nesting(model, args, kwargs)
    unconditioned_trace = poutine.trace(model).get_trace(*args, **kwargs)
    model_trace = prune_subsample_sites(unconditioned_trace)

    # Pad every posterior sample with singleton dims so that its leading
    # (sample) dimension sits to the left of all of the model's plate dims.
    reshaped_samples = {}
    for site_name, draws in posterior_samples.items():
        leading = draws.shape[0]
        trailing = draws.shape[1:]

        disagrees = num_samples is not None and num_samples != leading
        if disagrees:
            warnings.warn("Sample's leading dimension size {} is different from the "
                          "provided {} num_samples argument. Defaulting to {}."
                          .format(leading, num_samples, leading), UserWarning)
        if disagrees or num_samples is None:
            num_samples = leading

        padding = (1,) * (max_plate_nesting - len(trailing))
        reshaped_samples[site_name] = draws.reshape((num_samples,) + padding + trailing)

    if num_samples is None:
        raise ValueError("No sample sites in model to infer `num_samples`.")

    # Work out which sites to report and the shape each one should have.
    return_site_shapes = {}
    for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
        site_shape = (num_samples,) + model_trace.nodes[site]['value'].shape
        selected = site in return_sites if return_sites else site not in reshaped_samples
        if selected:
            return_site_shapes[site] = site_shape

    if not parallel:
        return _predictive_sequential(model, posterior_samples, args, kwargs, num_samples,
                                      return_site_shapes.keys(), return_trace)

    def vectorized_model(*fn_args, **fn_kwargs):
        # Run the model inside an outermost plate so that a single execution
        # yields all ``num_samples`` predictive draws.
        with pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1):
            return model(*fn_args, **fn_kwargs)

    conditioned = poutine.condition(vectorized_model, reshaped_samples)
    trace = poutine.trace(conditioned).get_trace(*args, **kwargs)

    if return_trace:
        return trace

    predictions = {}
    for site, shape in return_site_shapes.items():
        value = trace.nodes[site]['value']
        expected_numel = reduce((lambda acc, size: acc * size), shape)
        # Values with fewer elements than the target shape are broadcast;
        # everything else is reshaped to the target shape.
        needs_broadcast = value.numel() < expected_numel
        predictions[site] = value.expand(shape) if needs_broadcast else value.reshape(shape)

    return predictions
示例#38
0
文件: discrete.py 项目: zyxue/pyro
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):
    """
    Run ``model`` under enumeration, propagate results backward through the
    contracted log-probabilities, and replay ``model`` against the resulting
    collapsed trace.

    The steps are: (1) trace the model under ``EnumerateMessenger`` and pack
    its tensors; (2) group the packed ``log_prob`` of every sample site by
    ordinal and contract them with ``contract_tensor_tree`` using a ring built
    from ``temperature``; (3) trigger ``_pyro_backward`` on the contracted
    terms so that each unobserved site's ``log_prob`` receives a
    ``_pyro_backward_result``; (4) build a new trace of the unobserved sample
    sites with adjusted ``cond_indep_stack`` and gathered values; (5) run the
    model under ``SamplePosteriorMessenger`` with that trace.

    :param model: callable invoked as ``model(*args, **kwargs)``.
    :param first_available_dim: forwarded unchanged to ``EnumerateMessenger``.
    :param temperature: forwarded unchanged to ``_make_ring``; it selects the
        ring used for the contraction (exact semantics are defined by
        ``_make_ring``, which is not visible here).
    :param args: positional arguments for ``model``.
    :param kwargs: keyword arguments for ``model``.
    :return: the value returned by ``model(*args, **kwargs)`` when run inside
        ``SamplePosteriorMessenger(trace=collapsed_trace)``.
    """
    # For internal use by infer_discrete.

    # Create an enumerated trace.
    with poutine.block(), EnumerateMessenger(first_available_dim):
        enum_trace = poutine.trace(model).get_trace(*args, **kwargs)
    enum_trace = prune_subsample_sites(enum_trace)
    enum_trace.compute_log_prob()
    enum_trace.pack_tensors()
    plate_to_symbol = enum_trace.plate_to_symbol

    # Collect a set of query sample sites to which the backward algorithm will propagate.
    log_probs = OrderedDict()  # ordinal (frozenset of plate symbols) -> list of packed log_probs
    sum_dims = set()  # packed dims to be summed out by the contraction
    queries = []  # packed log_probs of the unobserved sample sites
    for node in enum_trace.nodes.values():
        if node["type"] == "sample":
            # The ordinal is the set of symbols of the vectorized plates enclosing this site.
            ordinal = frozenset(plate_to_symbol[f.name]
                                for f in node["cond_indep_stack"]
                                if f.vectorized)
            log_prob = node["packed"]["log_prob"]
            log_probs.setdefault(ordinal, []).append(log_prob)
            # Add every dim of this log_prob, then take the site's own vectorized
            # plate dims back out so that they are not summed over.
            # NOTE(review): set.remove raises KeyError if a plate symbol is absent
            # from sum_dims at this point — this relies on the packed log_prob
            # carrying every enclosing vectorized plate dim; confirm.
            sum_dims.update(log_prob._pyro_dims)
            for frame in node["cond_indep_stack"]:
                if frame.vectorized:
                    sum_dims.remove(plate_to_symbol[frame.name])
            # Note we mark all sample sites with require_backward to gather
            # enumerated sites and adjust cond_indep_stack of all sample sites.
            if not node["is_observed"]:
                queries.append(log_prob)
                require_backward(log_prob)

    # Run forward-backward algorithm, collecting the ordinal of each connected component.
    ring = _make_ring(temperature)
    log_probs = contract_tensor_tree(log_probs, sum_dims,
                                     ring=ring)  # run forward algorithm
    query_to_ordinal = {}
    # A unique sentinel is used rather than None because None is itself a
    # possible backward result (see the `sample is not None` check below).
    pending = object()  # a constant value for pending queries
    for query in queries:
        query._pyro_backward_result = pending
    for ordinal, terms in log_probs.items():
        for term in terms:
            if hasattr(term, "_pyro_backward"):
                term._pyro_backward()  # run backward algorithm
        # Any query whose result stopped being `pending` during this ordinal's
        # backward pass, and that has no ordinal yet, is assigned this ordinal.
        # Note: this is quadratic in number of ordinals
        for query in queries:
            if query not in query_to_ordinal and query._pyro_backward_result is not pending:
                query_to_ordinal[query] = ordinal

    # Construct a collapsed trace by gathering and adjusting cond_indep_stack.
    collapsed_trace = poutine.Trace()
    for node in enum_trace.nodes.values():
        if node["type"] == "sample" and not node["is_observed"]:
            # TODO move this into a Leaf implementation somehow
            # Start from a copy of the site's essentials; cond_indep_stack and
            # value may be overwritten below.
            new_node = {
                "type": "sample",
                "name": node["name"],
                "is_observed": False,
                "infer": node["infer"].copy(),
                "cond_indep_stack": node["cond_indep_stack"],
                "value": node["value"],
            }
            log_prob = node["packed"]["log_prob"]
            if hasattr(log_prob, "_pyro_backward_result"):
                # Adjust the cond_indep_stack.
                # Keep non-vectorized frames, and vectorized frames only if their
                # plate symbol belongs to the ordinal recorded for this query.
                ordinal = query_to_ordinal[log_prob]
                new_node["cond_indep_stack"] = tuple(
                    f for f in node["cond_indep_stack"]
                    if not f.vectorized or plate_to_symbol[f.name] in ordinal)

                # Gather if node depended on an enumerated value.
                sample = log_prob._pyro_backward_result
                if sample is not None:
                    new_value = packed.pack(node["value"],
                                            node["infer"]["_dim_to_symbol"])
                    # Pair each element yielded by jit_iter(sample) with the
                    # sample dim it indexes; gather only along dims that the
                    # packed value actually has.
                    # NOTE(review): presumably jit_iter iterates over the leading
                    # axis of `sample`, which is why each index is given
                    # sample._pyro_dims[1:] — confirm against jit_iter.
                    for index, dim in zip(jit_iter(sample),
                                          sample._pyro_sample_dims):
                        if dim in new_value._pyro_dims:
                            index._pyro_dims = sample._pyro_dims[1:]
                            new_value = packed.gather(new_value, index, dim)
                    new_node["value"] = packed.unpack(new_value,
                                                      enum_trace.symbol_to_dim)

            collapsed_trace.add_node(node["name"], **new_node)

    # Replay the model against the collapsed trace.
    with SamplePosteriorMessenger(trace=collapsed_trace):
        return model(*args, **kwargs)