示例#1
0
def test_replay_enumerate_poutine(depth, first_available_dim):
    """Replay an enumerated guide trace into an enumerated model.

    The guide enumerates ``y`` in parallel starting at dim
    ``depth + first_available_dim``.  The model enumerates ``depth`` Bernoulli
    sites before and ``depth`` after ``y``, starting at ``first_available_dim``.
    On every particle the replayed ``y`` must be the very same tensor object as
    the guide's, and the summed log-probability must have the expected
    broadcast shape.
    """
    num_particles = 2
    y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25]))

    def guide():
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

    traced_guide = poutine.trace(
        poutine.enum(guide, first_available_dim=depth + first_available_dim))
    guide_trace = traced_guide.get_trace()

    def model():
        pyro.sample("x", Bernoulli(0.5))
        # each site gets its own ``infer`` dict: enumeration may annotate it in place
        for k in range(depth):
            pyro.sample("a_{}".format(k), Bernoulli(0.5), infer={"enumerate": "parallel"})
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})
        for k in range(depth):
            pyro.sample("b_{}".format(k), Bernoulli(0.5), infer={"enumerate": "parallel"})

    enumerated_model = poutine.enum(model, first_available_dim=first_available_dim)
    traced_model = poutine.trace(poutine.replay(enumerated_model, trace=guide_trace))
    expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * first_available_dim

    for particle in range(num_particles):
        tr = traced_model.get_trace()
        assert tr.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]
        tr.compute_log_prob()
        log_prob = sum(site["log_prob"] for _, site in tr.iter_stochastic_nodes())
        assert log_prob.shape == expected_shape, 'error on iteration {}'.format(particle)
示例#2
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Run the guide and replay the model against it, ``num_particles`` times.

        Each iteration traces ``guide`` with a dense dependency graph, replays
        ``model`` against that guide trace, and (when validation is enabled)
        checks that model and guide match.  It warns if any guide sample site is
        configured for enumeration, because this estimator does not enumerate.
        Subsample sites are then pruned from both traces.

        :returns: generator of ``(weight, model_trace, guide_trace)`` tuples
            with ``weight == 1.0 / self.num_particles``.
        """
        for _ in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    # Three separate lines: header, the offending site names, advice.
                    # (The comma after the header is required; without it Python
                    # concatenates the header with ', ' and uses it as the separator.)
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration:',
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
示例#3
0
 def test_block_full(self):
     """Fully blocking model and guide leaves only args/return nodes in their traces."""
     model_trace = poutine.trace(poutine.block(self.model)).get_trace()
     guide_trace = poutine.trace(poutine.block(self.guide)).get_trace()
     allowed_types = ("args", "return")
     for trace in (model_trace, guide_trace):
         for node in trace.nodes.values():
             assert node["type"] in allowed_types
def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """Downstream costs agree with brute force when an iarange is reused.

    The guide is traced without observations, the model is replayed against it
    with observations, and ``_compute_downstream_costs`` must match
    ``_brute_force_compute_downstream_costs`` site-by-site.  It must also match
    a hand-built expected cost for ``c1``.
    """
    shape_kwargs = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(iarange_reuse_model_guide, graph_type="dense").get_trace(
        include_obs=False, **shape_kwargs)
    replayed = poutine.replay(iarange_reuse_model_guide, trace=guide_trace)
    model_trace = poutine.trace(replayed, graph_type="dense").get_trace(
        include_obs=True, **shape_kwargs)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])

    def elbo_term(site):
        # log p(site) - log q(site) for a latent site
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    # accumulate in place, exactly as the hand derivation adds terms
    expected_c1 = elbo_term('c1')
    expected_c1 += elbo_term('b1').sum()
    expected_c1 += elbo_term('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    """Downstream costs agree with brute force for an irange inside an iarange.

    Besides the site-by-site comparison with the brute-force implementation,
    the costs for ``b1``, ``c`` and ``a1`` are checked against hand-built sums
    of ``log p - log q`` terms plus observation log-probs.
    """
    shape_kwargs = dict(dim1=dim1, dim2=dim2)
    guide_trace = poutine.trace(nested_model_guide2, graph_type="dense").get_trace(
        include_obs=False, **shape_kwargs)
    replayed = poutine.replay(nested_model_guide2, trace=guide_trace)
    model_trace = poutine.trace(replayed, graph_type="dense").get_trace(
        include_obs=True, **shape_kwargs)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for site, cost in dc.items():
        assert guide_trace.nodes[site]['log_prob'].size() == cost.size()
        assert_equal(cost, dc_brute[site])

    def elbo_term(site):
        # log p(site) - log q(site) for a latent site
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    def obs_log_prob(site):
        return model_trace.nodes[site]['log_prob']

    expected_b1 = elbo_term('b1')
    expected_b1 += obs_log_prob('obs1')
    assert_equal(expected_b1, dc['b1'])

    expected_c = elbo_term('c')
    for i in range(dim2):
        expected_c += elbo_term('b{}'.format(i))
        expected_c += obs_log_prob('obs{}'.format(i))
    assert_equal(expected_c, dc['c'])

    expected_a1 = elbo_term('a1')
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
示例#6
0
 def test_splice(self):
     """After lifting with ``self.prior``, the mu/sigma sites are absent and all others remain."""
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior)).get_trace()
     removed_names = ('mu1', 'mu2', 'sigma1', 'sigma2')
     for name in tr.nodes:
         present = name in lifted_tr
         if name in removed_names:
             self.assertFalse(present)
         else:
             self.assertTrue(present)
示例#7
0
 def test_trace_data(self):
     """Conditioning on a trace turns its sites into observations with identical values."""
     sample_only = poutine.block(self.model, expose_types=["sample"])
     tr1 = poutine.trace(sample_only).get_trace()
     tr2 = poutine.trace(poutine.condition(self.model, data=tr1)).get_trace()
     latent2 = tr2.nodes["latent2"]
     assert latent2["type"] == "sample" and latent2["is_observed"]
     assert latent2["value"] is tr1.nodes["latent2"]["value"]
示例#8
0
 def test_splice(self):
     """After lifting with ``self.prior``, the loc/scale sites are absent and all others remain."""
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior)).get_trace()
     removed_names = ('loc1', 'loc2', 'scale1', 'scale2')
     for name in tr.nodes:
         # a site survives lifting exactly when it is not one of the removed names
         assert (name in lifted_tr) != (name in removed_names)
示例#9
0
    def test_block_tutorial_case(self):
        """Hiding "observe" sites in the guide keeps the latent but drops ``obs``."""
        model_trace = poutine.trace(self.model).get_trace()
        blocked_guide = poutine.block(self.guide, hide_types=["observe"])
        guide_trace = poutine.trace(blocked_guide).get_trace()

        for site in ("latent1", "obs"):
            assert site in model_trace
        assert "latent1" in guide_trace
        assert "obs" not in guide_trace
示例#10
0
 def test_block_full_expose(self):
     """Exposing every declared site through ``block`` keeps all of them in the trace."""
     exposed_model = poutine.block(self.model, expose=self.model_sites)
     exposed_guide = poutine.block(self.guide, expose=self.guide_sites)
     model_trace = poutine.trace(exposed_model).get_trace()
     guide_trace = poutine.trace(exposed_guide).get_trace()
     assert all(name in model_trace for name in self.model_sites)
     assert all(name in guide_trace for name in self.guide_sites)
示例#11
0
def test_iarange_error_on_enter():
    """An iarange that fails on entry must not leak a frame on the dim allocator stack."""
    def model():
        # the test expects entering a size-0 iarange to raise ZeroDivisionError
        with pyro.iarange('foo', 0):
            pass

    def stack_depth():
        return len(_DIM_ALLOCATOR._stack)

    assert stack_depth() == 0
    with pytest.raises(ZeroDivisionError):
        poutine.trace(model)()
    assert stack_depth() == 0, 'stack was not cleaned on error'
示例#12
0
 def test_prior_dict(self):
     """Lifting with a prior dict keeps every site and uses ``<name>_prior`` for listed params."""
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior_dict)).get_trace()
     named_prior_sites = {'sigma1', 'mu1', 'sigma2', 'mu2'}
     for name, node in tr.nodes.items():
         self.assertIn(name, lifted_tr)
         lifted = lifted_tr.nodes[name]
         if name in named_prior_sites:
             self.assertEqual(name + "_prior", lifted['fn'].__name__)
         if node["type"] == "param":
             # a lifted param becomes a latent (unobserved) sample site
             self.assertTrue(lifted["type"] == "sample" and not lifted["is_observed"])
示例#13
0
 def test_prior_dict(self):
     """Lifting with a prior dict keeps every site and uses ``<name>_prior`` for listed params."""
     tr = poutine.trace(self.guide).get_trace()
     lifted_tr = poutine.trace(poutine.lift(self.guide, prior=self.prior_dict)).get_trace()
     named_prior_sites = {'scale1', 'loc1', 'scale2', 'loc2'}
     for name, node in tr.nodes.items():
         assert name in lifted_tr
         lifted = lifted_tr.nodes[name]
         if name in named_prior_sites:
             assert lifted['fn'].__name__ == name + "_prior"
         if node["type"] == "param":
             # a lifted param becomes a latent (unobserved) sample site
             assert lifted["type"] == "sample"
             assert not lifted["is_observed"]
示例#14
0
 def _traces(self, *args, **kwargs):
     """
     Generator of weighted samples from the proposal distribution.

     Each of ``self.num_samples`` draws traces ``self.guide`` (the proposal),
     replays ``self.model`` against it, and yields ``(model_trace, log_weight)``
     where ``log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()``.
     """
     for _ in range(self.num_samples):
         guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
         replayed_model = poutine.replay(self.model, trace=guide_trace)
         model_trace = poutine.trace(replayed_model).get_trace(*args, **kwargs)
         log_p = model_trace.log_prob_sum()
         log_q = guide_trace.log_prob_sum()
         yield model_trace, log_p - log_q
示例#15
0
 def test_replay_full_repeat(self):
     """Replaying a trace, repeatedly or via a fresh handler, reproduces every sample value."""
     model_trace = poutine.trace(self.model).get_trace()
     replaying = poutine.trace(poutine.replay(self.model, trace=model_trace))
     first = replaying.get_trace()
     second = replaying.get_trace()
     fresh = poutine.trace(poutine.replay(self.model, trace=model_trace)).get_trace()
     pairs = ((first, second), (first, fresh), (model_trace, first), (model_trace, fresh))
     for name in self.full_sample_sites.keys():
         for lhs, rhs in pairs:
             assert_equal(lhs.nodes[name]["value"], rhs.nodes[name]["value"])
示例#16
0
    def test_trace_full(self):
        """Traces contain only declared sites; guide sample sites are never observed."""
        guide_trace = poutine.trace(self.guide).get_trace()
        model_trace = poutine.trace(self.model).get_trace()
        for name in model_trace.nodes:
            assert name in self.model_sites

        valid_types = ("args", "return", "sample", "param")
        for name, node in guide_trace.nodes.items():
            assert name in self.guide_sites
            assert node["type"] in valid_types
            if node["type"] == "sample":
                assert not node["is_observed"]
示例#17
0
 def test_block_partial_expose(self):
     """Exposing a subset of sites keeps exactly that subset in model and guide traces."""
     exposed = self.partial_sample_sites.keys()
     model_trace = poutine.trace(poutine.block(self.model, expose=exposed)).get_trace()
     guide_trace = poutine.trace(poutine.block(self.guide, expose=exposed)).get_trace()
     for name in self.full_sample_sites.keys():
         should_appear = name in self.partial_sample_sites
         assert (name in model_trace) == should_appear
         assert (name in guide_trace) == should_appear
示例#18
0
 def test_replay_partial(self):
     """Replaying only some sites copies those values and resamples the rest."""
     guide_trace = poutine.trace(self.guide).get_trace()
     partial_replay = poutine.replay(self.model, guide_trace, sites=self.partial_sample_sites)
     model_trace = poutine.trace(partial_replay).get_trace()
     for name in self.full_sample_sites.keys():
         model_value = model_trace.nodes[name]["value"]
         guide_value = guide_trace.nodes[name]["value"]
         if name not in self.partial_sample_sites:
             assert not eq(model_value, guide_value)
         else:
             assert_equal(model_value, guide_value)
示例#19
0
 def test_trace_compose(self):
     """Site ``x`` is recorded when trace wraps escape, but not when escape wraps trace."""
     # trace outside escape: the escaping site is still recorded
     tm = poutine.trace(self.model)
     try:
         poutine.escape(tm, functools.partial(all_escape, poutine.Trace()))()
     except NonlocalExit:
         pass
     else:
         assert False
     assert "x" in tm.trace

     # escape inside trace: the site escapes before it can be recorded
     tem = poutine.trace(
         poutine.escape(self.model, functools.partial(all_escape, poutine.Trace())))
     try:
         tem()
     except NonlocalExit:
         pass
     else:
         assert False
     assert "x" not in tem.trace
示例#20
0
def test_scores(auto_class):
    """The auto-guide scores its joint latent, while the model scores ``z``.

    Within the guide trace the joint ``_auto_latent`` carries the nonzero
    log-prob and the ``z`` site contributes zero.
    """
    def model():
        pyro.sample("z", dist.Normal(0.0, 1.0))

    guide = auto_class(model)
    guide_trace = poutine.trace(guide).get_trace()
    model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace()

    for tr in (guide_trace, model_trace):
        tr.compute_log_prob()

    def site_log_prob(trace, site):
        return trace.nodes[site]['log_prob_sum'].item()

    assert '_auto_latent' not in model_trace.nodes
    assert site_log_prob(model_trace, 'z') != 0.0
    assert site_log_prob(guide_trace, '_auto_latent') != 0.0
    assert site_log_prob(guide_trace, 'z') == 0.0
示例#21
0
文件: __init__.py 项目: lewisKit/pyro
    def _setup_prototype(self, *args, **kwargs):
        """Trace the model once and record the structure of its latent sites.

        The model is wrapped with ``config_enumerate(..., default="parallel")``
        and run under ``poutine.block`` to produce ``self.prototype_trace``,
        from which subsample sites are pruned.  If ``self.master`` is set, the
        object it returns checks the prototype.  For every unobserved sample
        site this records:

        * ``self._discrete_sites``: ``(site, Dist, params)`` with a detached
          copy of ``probs`` and its constraint;
        * ``self._cond_indep_stacks``: site name -> ``cond_indep_stack``;
        * ``self._iaranges``: frame name -> vectorized frame.

        :raises NotImplementedError: if a latent site is not configured for
            parallel enumeration, if its distribution is not Bernoulli,
            Categorical or OneHotCategorical, or if it sits in a
            non-vectorized (irange) frame.
        """
        enumerated_model = config_enumerate(self.model, default="parallel")
        self.prototype_trace = poutine.block(poutine.trace(enumerated_model).get_trace)(*args, **kwargs)
        self.prototype_trace = prune_subsample_sites(self.prototype_trace)
        if self.master is not None:
            self.master()._check_prototype(self.prototype_trace)

        self._discrete_sites = []
        self._cond_indep_stacks = {}
        self._iaranges = {}
        supported_dists = (dist.Bernoulli, dist.Categorical, dist.OneHotCategorical)
        for name, site in self.prototype_trace.nodes.items():
            is_latent = site["type"] == "sample" and not site["is_observed"]
            if not is_latent:
                continue
            if site["infer"].get("enumerate") != "parallel":
                raise NotImplementedError('Expected sample site "{}" to be discrete and '
                                          'configured for parallel enumeration'.format(name))

            # discrete sample site: remember its family and a detached copy of its probs
            fn = site["fn"]
            Dist = type(fn)
            if Dist not in supported_dists:
                raise NotImplementedError("{} is not supported".format(Dist.__name__))
            params = [("probs", fn.probs.detach().clone(), fn.arg_constraints["probs"])]
            self._discrete_sites.append((site, Dist, params))

            # independence contexts: only vectorized frames are supported
            self._cond_indep_stacks[name] = site["cond_indep_stack"]
            for frame in site["cond_indep_stack"]:
                if not frame.vectorized:
                    raise NotImplementedError("AutoDiscreteParallel does not support pyro.irange")
                self._iaranges[frame.name] = frame
示例#22
0
文件: hmc.py 项目: lewisKit/pyro
    def setup(self, *args, **kwargs):
        """Trace the model once and prepare per-site sampler state.

        Stores ``args``/``kwargs`` and a prototype trace, which is used to
        convert between trace objects and the dict used by the integrator.
        For each stochastic site, visited in name order:

        * when automatic transforms are enabled and the support is not
          ``constraints.real``, registers ``biject_to(support).inv`` in
          ``self.transforms`` and maps the value through it;
        * stores a standard Normal of the (possibly transformed) value's shape
          in ``self._r_dist`` (presumably the momentum distribution).

        If ``self.adapt_step_size`` is set, it finds a reasonable initial step
        size, derives ``num_steps`` from ``trajectory_length`` (at least 1),
        and builds a ``DualAveraging`` scheme centered at
        ``log(10 * step_size)``.

        NOTE(review): when automatic transforms are disabled, ``self.transforms``
        is not reset here but is still read below; it must already exist.
        """
        self._args = args
        self._kwargs = kwargs
        trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        self._prototype_trace = trace
        if self._automatic_transform_enabled:
            self.transforms = {}

        by_name = sorted(trace.iter_stochastic_nodes(), key=lambda item: item[0])
        for site_name, site in by_name:
            value = site["value"]
            support = site["fn"].support
            if support is not constraints.real and self._automatic_transform_enabled:
                self.transforms[site_name] = biject_to(support).inv
                value = self.transforms[site_name](site["value"])
            self._r_dist[site_name] = dist.Normal(loc=value.new_zeros(value.shape),
                                                  scale=value.new_ones(value.shape))
        self._validate_trace(trace)

        if not self.adapt_step_size:
            return

        self._adapt_phase = True
        z = {site_name: site["value"] for site_name, site in trace.iter_stochastic_nodes()}
        for site_name, transform in self.transforms.items():
            z[site_name] = transform(z[site_name])
        self.step_size = self._find_reasonable_step_size(z)
        self.num_steps = max(1, int(self.trajectory_length / self.step_size))
        # prox-center for the Dual Averaging scheme
        self._adapted_scheme = DualAveraging(prox_center=math.log(10 * self.step_size))
示例#23
0
文件: hmc.py 项目: lewisKit/pyro
 def _get_trace(self, z):
     """Replay the model with latent values taken from ``z`` and return the new trace.

     NOTE(review): the values from ``z`` are written directly into
     ``self._prototype_trace`` (no copy is made), so the prototype is mutated
     by every call.
     """
     prototype = self._prototype_trace
     for site_name, site_value in z.items():
         prototype.nodes[site_name]["value"] = site_value
     tracer = poutine.trace(poutine.replay(self.model, trace=prototype))
     tracer(*self._args, **self._kwargs)
     return tracer.trace
示例#24
0
 def _traces(self, *args, **kwargs):
     """Yield ``(trace, 0.)`` pairs by replaying the model against posterior samples.

     If the posterior has no execution traces yet, it is run first with the
     same arguments.  Each yielded trace comes from replaying ``self.model``
     against one draw ``self.posterior()``.
     """
     if not self.posterior.exec_traces:
         self.posterior.run(*args, **kwargs)
     for _ in range(self.num_samples):
         posterior_sample = self.posterior()
         replayer = poutine.trace(poutine.replay(self.model, posterior_sample))
         yield replayer.get_trace(*args, **kwargs), 0.
示例#25
0
def test_decorator_interface_primitives():
    """``poutine.trace`` works both as a bare decorator and as a configured one.

    ``@poutine.trace`` must produce a trace with graph type ``"flat"``;
    ``@poutine.trace(graph_type="dense")`` must produce a ``"dense"`` one.  The
    decorated model must also still be replayable against its own trace.
    """

    # bare decorator form
    @poutine.trace
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "flat"

    # decorator-with-arguments form
    @poutine.trace(graph_type="dense")
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "dense"

    # replaying the decorated model against its own trace reproduces site "a"
    tr2 = poutine.trace(poutine.replay(model, trace=tr)).get_trace()

    assert_equal(tr2.nodes["a"]["value"], tr.nodes["a"]["value"])
示例#26
0
 def test_condition(self):
     """Conditioning observes ``latent2`` using the exact data tensor supplied."""
     data = {"latent2": torch.randn(2)}
     conditioned = poutine.condition(self.model, data=data)
     tr2 = poutine.trace(conditioned).get_trace()
     assert "latent2" in tr2
     node = tr2.nodes["latent2"]
     assert node["type"] == "sample" and node["is_observed"]
     assert node["value"] is data["latent2"]
示例#27
0
def test_optimizers(factory):
    """Each optimizer drives params x, y and z to ``loc``.

    ``factory`` builds the optimizer under test.  Every row of the three params
    is observed under the same MultivariateNormal(loc, cov), so minimizing the
    negative log-likelihood for 100 steps should move each row to ``loc``
    (checked to ``prec=1e-2``).
    """
    optim = factory()
    param_names = ["x", "y", "z"]

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
        pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
        with pyro.iarange("y_iarange", 3):
            pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
        with pyro.iarange("z_iarange", 4):
            pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)

    loc = torch.tensor([-0.5, 0.5])
    cov = torch.tensor([[1.0, 0.09], [0.09, 0.1]])
    for _ in range(100):
        tr = poutine.trace(model).get_trace(loc, cov)
        loss = -tr.log_prob_sum()
        # optimize in unconstrained space (z carries a constraint)
        params = {name: pyro.param(name).unconstrained() for name in param_names}
        optim.step(loss, params)

    for name in param_names:
        actual = pyro.param(name)
        expected = loc.expand(actual.shape)
        assert_equal(actual, expected, prec=1e-2,
                     msg='{} incorrect: {} vs {}'.format(name, actual, expected))
示例#28
0
 def test_random_module(self):
     """``random_module`` lifts every module parameter to a latent sample site.

     After lifting with ``self.prior``, no ``param`` site may remain in the
     trace, and every sample site (the lifted parameters) must be unobserved.
     """
     pyro.clear_param_store()
     lifted_tr = poutine.trace(pyro.random_module("name", self.model, prior=self.prior)).get_trace()
     for name, node in lifted_tr.nodes.items():
         # checking type == "sample" only when type == "param" can never pass,
         # so state the intent directly: no param site survives lifting
         assert node["type"] != "param", "param site {} was not lifted".format(name)
         if node["type"] == "sample":
             assert not node["is_observed"]
示例#29
0
    def test_infer_config_sample(self):
        """``infer_config`` merges ``config_fn``'s output into each site's ``infer`` dict."""
        tr = poutine.trace(poutine.infer_config(self.model, config_fn=self.config_fn)).get_trace()

        expected_infer = {
            "a": {"enumerate": "parallel", "blah": True},
            "b": {"blah": True},
            "p": {},
        }
        for site, infer in expected_infer.items():
            assert tr.nodes[site]["infer"] == infer
示例#30
0
 def _test_scale_factor(batch_size_outer, batch_size_inner, expected):
     """Compare the ``scale`` of the present ``z_i_j`` sites, in order, with ``expected``.

     ``self`` comes from the enclosing scope (this helper is a closure).
     """
     trace = poutine.trace(self.model, graph_type="dense").get_trace(
         batch_size_outer=batch_size_outer, batch_size_inner=batch_size_inner)
     candidate_sites = ['z_0_0', 'z_0_1', 'z_1_0', 'z_1_1']
     scale_factors = [trace.nodes[site]['scale'] for site in candidate_sites if site in trace]
     assert scale_factors == expected
示例#31
0
 def _traces(self, *args, **kwargs):
     """Yield up to ``self.num_samples`` ``(trace, log_prob_sum)`` pairs.

     A priority queue is seeded with an empty ``poutine.Trace()``, and the
     model is wrapped with ``pqueue`` (defined elsewhere).  It presumably
     extends partial traces taken from that queue.  Generation stops early
     once the queue is empty.
     """
     q = queue.PriorityQueue()
     # priority 0 minus a tiny random jitter, to break ties
     seed_priority = torch.zeros(1).item() - torch.rand(1).item() * 1e-2
     q.put((seed_priority, poutine.Trace()))
     q_fn = pqueue(self.model, queue=q)
     for _ in range(self.num_samples):
         if q.empty():
             break  # num_samples was too large!
         tr = poutine.trace(q_fn).get_trace(*args, **kwargs)  # XXX should block
         yield tr, tr.log_prob_sum()
示例#32
0
def test_scores(auto_class):
    """The auto-guide scores its joint latent, while the model scores ``z``.

    The latent site is named ``_<AutoClassName>_latent``.  For
    ``AutoIAFNormal`` the model site is a batch of 10 normals, otherwise a
    scalar normal.
    """
    def model():
        z_dist = dist.Normal(0.0, 1.0)
        if auto_class is AutoIAFNormal:
            z_dist = z_dist.expand([10])
        pyro.sample("z", z_dist)

    guide = auto_class(model)
    guide_trace = poutine.trace(guide).get_trace()
    model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace()

    for tr in (guide_trace, model_trace):
        tr.compute_log_prob()

    def site_log_prob(trace, site):
        return trace.nodes[site]['log_prob_sum'].item()

    latent_name = '_{}_latent'.format(auto_class.__name__)
    assert latent_name not in model_trace.nodes
    assert site_log_prob(model_trace, 'z') != 0.0
    assert site_log_prob(guide_trace, latent_name) != 0.0
    assert site_log_prob(guide_trace, 'z') == 0.0
示例#33
0
    def init(self, state):
        """Seed ``state`` from the model's global sample sites, then take one step.

        ``self.model.global_model()`` runs under a trace, and every sample site
        value is copied into ``state`` by name.  Then ``self.t`` is reset to 0,
        ``state`` is updated with ``self.model.initialize(params)``, and
        ``self.step(state)`` is called once.
        """
        with poutine.trace() as tracer:
            params = self.model.global_model()
        for site_name, site in tracer.trace.nodes.items():
            if site["type"] != "sample":
                continue
            state[site_name] = site["value"]

        self.t = 0
        state.update(self.model.initialize(params))
        # model.initialize is deterministic, so take one step right away
        self.step(state)
示例#34
0
def WAIC(model, x, y, out_var_nm, num_samples=100):
    """Pointwise WAIC on the deviance scale, estimated from posterior samples.

    For each of ``num_samples`` draws, ``model.guide()`` supplies values that
    condition ``model``, which is traced on ``x``.  The distribution recorded
    at site ``out_var_nm`` then scores ``y``.  From the resulting
    ``(num_samples, len(y))`` log-likelihood matrix ``L``:

    * ``lppd = log(mean_s exp(L[s]))``, computed stably via ``logsumexp``;
    * ``penalty = var_s(L[s])``.

    :returns: tensor of shape ``(len(y),)`` with ``-2 * (lppd - penalty)`` for
        each observation.
    """
    import math

    log_lik = torch.zeros((num_samples, len(y)))
    # Get log probability samples
    for i in range(num_samples):
        tr = poutine.trace(poutine.condition(model, data=model.guide())).get_trace(x)
        # named out_dist so the module-level ``dist`` is not shadowed
        out_dist = tr.nodes[out_var_nm]["fn"]
        log_lik[i] = out_dist.log_prob(y).detach()
    # log-mean-exp over samples (numerically stable)
    lppd = torch.logsumexp(log_lik, dim=0) - math.log(num_samples)
    penalty = log_lik.var(dim=0)
    return -2 * (lppd - penalty)
示例#35
0
def aic_num_parameters(model, guide=None):
    """
    Hacky AIC parameter count: total number of scalar parameter entries.

    Runs ``model`` (and ``guide`` if given) under ``poutine.block`` with a
    param-only trace, and sums ``numel()`` over every recorded param node.
    """
    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        model()
        if guide is not None:
            guide()

    total = 0
    for site in param_capture.trace.nodes.values():
        total += site["value"].numel()
    return total
示例#36
0
def score_latent(zs, ys):
    """Joint log-probability of a fresh ``HarmonicModel`` run with latents fixed to ``zs``.

    Site ``z_t`` is conditioned on ``zs[t]``.  The model is initialised, then
    stepped once for each observation in ``ys[1:]``, and the summed
    log-probability of the resulting trace is returned.
    """
    model = HarmonicModel()
    latent_data = {"z_{}".format(t): z for t, z in enumerate(zs)}
    with poutine.trace() as tracer:
        with poutine.condition(data=latent_data):
            model.init()
            for y in ys[1:]:
                model.step(y)

    return tracer.trace.log_prob_sum()
示例#37
0
    def test_infer_config_sample(self):
        """``infer_config`` merges ``config_fn``'s output into each site's ``infer`` dict."""
        configured = poutine.infer_config(self.model, config_fn=self.config_fn)
        nodes = poutine.trace(configured).get_trace().nodes

        assert nodes["a"]["infer"] == {"enumerate": "parallel", "blah": True}
        assert nodes["b"]["infer"] == {"blah": True}
        assert nodes["p"]["infer"] == {}
示例#38
0
def mc_elbo_with_l2(model, guide, data1, data2, lam=0.01):
    """Single-sample negative ELBO plus an L2 penalty on param sites.

    Traces ``guide`` on ``(data1, data2)``, replays ``model`` against that
    trace, and returns ``log q - log p + penalty``.  The penalty adds
    ``lam * sum(value ** 2)`` for every ``param`` node of the model trace and
    then of the guide trace.  A param recorded in both traces is therefore
    penalised twice.
    """
    guide_trace = poutine.trace(guide).get_trace(data1, data2)
    replayed_model = poutine.replay(model, trace=guide_trace)
    model_trace = poutine.trace(replayed_model).get_trace(data1, data2)
    logp = model_trace.log_prob_sum()
    logq = guide_trace.log_prob_sum()

    # add L2 regularization over model params, then guide params
    param_values = [node["value"]
                    for tr in (model_trace, guide_trace)
                    for node in tr.nodes.values()
                    if node["type"] == "param"]
    penalty = 0
    for value in param_values:
        penalty = penalty + lam * torch.sum(torch.pow(value, 2))

    return logq - logp + penalty
示例#39
0
 def test_block_expose_fn(self):
     """With ``expose_fn`` given, every "latent" site is exposed, even one listed in ``hide``."""
     blocked = poutine.block(
         self.model,
         expose_fn=lambda msg: "latent" in msg["name"],
         hide=["latent1"],
     )
     model_trace = poutine.trace(blocked).get_trace()
     for site in ("latent1", "latent2"):
         assert site in model_trace
     assert "obs" not in model_trace
示例#40
0
 def test_random_module_prior_dict(self):
     """Each prior-dict entry of ``random_module`` lifts the matching module parameter.

     The ``fc.weight`` and ``fc.bias`` sites must be unobserved samples whose
     distribution has ``__name__ == "<param>_prior"`` (e.g. ``"weight_prior"``).
     """
     pyro.clear_param_store()
     lifted_nn = pyro.random_module("name", self.model, prior=self.nn_prior)
     lifted_tr = poutine.trace(lifted_nn).get_trace()
     # "fc.bias" (not "fc.prior"): the expected fn name is derived from the
     # parameter name, so "fc.prior" could never match a real parameter site
     checked_params = {"fc.weight", "fc.bias"}
     for key_name, node in lifted_tr.nodes.items():
         name = pyro.params.user_param_name(key_name)
         if name not in checked_params:
             continue
         dist_name = name[len("fc."):]  # "fc.weight" -> "weight"
         assert dist_name + "_prior" == node["fn"].__name__
         assert node["type"] == "sample"
         assert not node["is_observed"]
示例#41
0
 def test_trace_compose(self):
     """Site ``x`` is recorded when trace wraps escape, but not when escape wraps trace."""
     # trace outside escape: the escaping site is still recorded
     tm = poutine.trace(self.model)
     try:
         poutine.escape(
             tm, escape_fn=functools.partial(all_escape, poutine.Trace())
         )()
     except NonlocalExit:
         pass
     else:
         assert False
     assert "x" in tm.trace

     # escape inside trace: the site escapes before it can be recorded
     tem = poutine.trace(
         poutine.escape(
             self.model,
             escape_fn=functools.partial(all_escape, poutine.Trace()),
         )
     )
     try:
         tem()
     except NonlocalExit:
         pass
     else:
         assert False
     assert "x" not in tem.trace
示例#42
0
def test_pickling(wrapper):
    """Round-trip a wrapped model through ``torch.save`` / ``torch.load``.

    The deserialized copy must produce the same site names, and the same
    sampled values at stochastic sites, as the original under the same RNG
    seed.
    """
    wrapped = wrapper(_model)
    buffer = io.BytesIO()
    # default protocol cannot serialize torch.Size objects (see https://github.com/pytorch/pytorch/issues/20823)
    torch.save(wrapped, buffer, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    buffer.seek(0)
    # The buffer holds an arbitrary pickled callable, not just tensors, so full
    # unpickling is required. PyTorch >= 2.6 defaults to weights_only=True,
    # which rejects such objects. This is safe because the buffer was just
    # written above.
    deserialized = torch.load(buffer, weights_only=False)
    obs = torch.tensor(0.5)
    pyro.set_rng_seed(0)
    actual_trace = poutine.trace(deserialized).get_trace(obs)
    pyro.set_rng_seed(0)
    expected_trace = poutine.trace(wrapped).get_trace(obs)
    assert tuple(actual_trace.nodes) == tuple(expected_trace.nodes)
    assert_close(
        [actual_trace.nodes[site]["value"] for site in actual_trace.stochastic_nodes],
        [
            expected_trace.nodes[site]["value"]
            for site in expected_trace.stochastic_nodes
        ],
    )
示例#43
0
    def __init__(self,
                 model,
                 data,
                 covariates=None,
                 *,
                 num_warmup=1000,
                 num_samples=1000,
                 num_chains=1,
                 time_reparam=None,
                 dense_mass=False,
                 jit_compile=False,
                 max_tree_depth=10):
        """
        Fit ``model`` to ``data`` with NUTS and keep the posterior samples.

        :param model: the model, called as ``model(data, covariates)``.
        :param data: observed tensor. Its ``size(-2)`` must equal that of
            ``covariates``.
        :param covariates: covariates tensor. If ``None``, a zero-width tensor
            with ``data``'s leading shape is used, meaning "no covariates".
        :param int num_warmup: number of NUTS warmup steps.
        :param int num_samples: number of samples drawn per chain.
        :param int num_chains: number of MCMC chains.
        :param time_reparam: ``"haar"``, ``"dct"`` or ``None``. Selects the
            ``poutine.reparam`` strategy applied to ``model``.
        :param bool dense_mass: forwarded to NUTS as ``full_mass``.
        :param bool jit_compile: forwarded to NUTS.
        :param int max_tree_depth: forwarded to NUTS.
        :raises ValueError: if ``time_reparam`` is not one of the values above.
        """
        if covariates is None:
            # Use the zero-width covariates convention for "no covariates".
            # Previously the default None crashed on ``covariates.size``.
            covariates = data.new_zeros(data.shape[:-1] + (0,))
        assert data.size(-2) == covariates.size(-2)
        super().__init__()
        if time_reparam == "haar":
            model = poutine.reparam(model, time_reparam_haar)
        elif time_reparam == "dct":
            model = poutine.reparam(model, time_reparam_dct)
        elif time_reparam is not None:
            raise ValueError("unknown time_reparam: {}".format(time_reparam))
        self.model = model
        max_plate_nesting = _guess_max_plate_nesting(model, (data, covariates),
                                                     {})
        self.max_plate_nesting = max(max_plate_nesting,
                                     1)  # force a time plate

        # NOTE(review): NUTS receives the guessed nesting, not the forced
        # ``self.max_plate_nesting`` used below. Confirm this is intended.
        kernel = NUTS(model,
                      full_mass=dense_mass,
                      jit_compile=jit_compile,
                      ignore_jit_warnings=True,
                      max_tree_depth=max_tree_depth,
                      max_plate_nesting=max_plate_nesting)
        mcmc = MCMC(kernel,
                    warmup_steps=num_warmup,
                    num_samples=num_samples,
                    num_chains=num_chains)
        mcmc.run(data, covariates)
        # conditions to compute rhat
        if (num_chains == 1 and num_samples >= 4) or (num_chains > 1
                                                      and num_samples >= 2):
            mcmc.summary()

        # inspect the model with particles plate = 1, so that we can reshape samples to
        # add any missing plate dim in front.
        with poutine.trace() as tr:
            with pyro.plate("particles", 1, dim=-self.max_plate_nesting - 1):
                model(data, covariates)

        self._trace = tr.trace
        self._samples = mcmc.get_samples()
        self._num_samples = num_samples * num_chains
        # Keep only the trace sites that have MCMC samples. Iterate over a
        # snapshot of the keys because entries are deleted while looping.
        for name in list(self._trace.nodes):
            if name not in self._samples:
                del self._trace.nodes[name]
示例#44
0
    def _guess_max_plate_nesting(self, model, guide, args, kwargs):
        """
        Infer ``self.max_plate_nesting`` from a single non-enumerated run.

        The guide is traced, the model is replayed against that trace, and
        the smallest (most negative) vectorized plate ``dim`` over all sample
        sites of both traces gives the nesting depth. This optimistically
        assumes the model structure is static. One extra dim is reserved when
        ``self.vectorize_particles`` is set and ``self.num_particles > 1``.
        """
        # Ignore validation to allow model-enumerated sites absent from the guide.
        with poutine.block():
            guide_tr = poutine.trace(guide).get_trace(*args, **kwargs)
            replayed_model = poutine.replay(model, trace=guide_tr)
            model_tr = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        guide_tr = prune_subsample_sites(guide_tr)
        model_tr = prune_subsample_sites(model_tr)

        sample_sites = []
        for tr in (model_tr, guide_tr):
            sample_sites.extend(
                node for node in tr.nodes.values() if node["type"] == "sample"
            )

        # Shapes are validated here, while max_plate_nesting is still
        # unbounded. The traces are known not to be enumerated at this point;
        # later, broadcasting left of max_plate_nesting must be allowed.
        if is_validation_enabled():
            guide_tr.compute_log_prob()
            model_tr.compute_log_prob()
            for node in sample_sites:
                check_site_shape(node, max_plate_nesting=float("inf"))

        vectorized_dims = (
            frame.dim
            for node in sample_sites
            for frame in node["cond_indep_stack"]
            if frame.vectorized
        )
        self.max_plate_nesting = -min(vectorized_dims, default=0)
        if self.vectorize_particles and self.num_particles > 1:
            self.max_plate_nesting += 1
        logging.info("Guessed max_plate_nesting = {}".format(self.max_plate_nesting))
示例#45
0
    def loss_fn(design, num_particles, evaluation=False, **kwargs):
        """Nested Monte Carlo loss built from the enclosing model/guide.

        Presumably the variational nested Monte Carlo (VNMC) EIG loss; confirm
        against the enclosing function.

        :param design: design tensor, expanded ``N`` and then ``M`` times via
            ``lexpand``. Assumes ``lexpand`` prepends batch dims, so dim 0 of
            the final terms indexes the ``M`` inner samples (TODO confirm).
        :param num_particles: pair ``(N, M)``: ``N`` outer samples of
            ``(y, theta)`` and ``M`` inner guide samples per ``y``.
        :param bool evaluation: if True, also add ``log p(y | theta, d)`` from
            the outer trace.
        :param kwargs: accepted but not used.
        :return: ``_safe_mean_terms(terms)``.
        """
        N, M = num_particles
        expanded_design = lexpand(design, N)

        # Sample from p(y, theta | d)
        trace = poutine.trace(model).get_trace(expanded_design)
        # Repeat each observed y M times so the guide can be run M times per y.
        y_dict = {
            l: lexpand(trace.nodes[l]["value"], M)
            for l in observation_labels
        }

        # Sample M times from q(theta | y, d) for each y
        reexpanded_design = lexpand(expanded_design, M)
        conditional_guide = pyro.condition(guide, data=y_dict)
        guide_trace = poutine.trace(conditional_guide).get_trace(
            y_dict, reexpanded_design, observation_labels, target_labels)
        theta_y_dict = {
            l: guide_trace.nodes[l]["value"]
            for l in target_labels
        }
        theta_y_dict.update(y_dict)
        guide_trace.compute_log_prob()

        # Re-run that through the model to compute the joint
        modelp = pyro.condition(model, data=theta_y_dict)
        model_trace = poutine.trace(modelp).get_trace(reexpanded_design)
        model_trace.compute_log_prob()

        # log p(theta, y | d) - log q(theta | y, d) for every inner sample,
        # then -log of the mean importance weight over dim 0
        # (-logsumexp + log M).
        terms = -sum(guide_trace.nodes[l]["log_prob"] for l in target_labels)
        terms += sum(model_trace.nodes[l]["log_prob"] for l in target_labels)
        terms += sum(model_trace.nodes[l]["log_prob"]
                     for l in observation_labels)
        terms = -terms.logsumexp(0) + math.log(M)

        # At eval time, add p(y | theta, d) terms
        if evaluation:
            trace.compute_log_prob()
            terms += sum(trace.nodes[l]["log_prob"]
                         for l in observation_labels)

        return _safe_mean_terms(terms)
示例#46
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        Generate one ``(weight, model_trace, guide_trace)`` triple per particle.

        Each particle traces the guide with a dense graph, replays the model
        against that guide trace (also traced densely), checks that the two
        traces match, and prunes subsample sites from both. Every particle
        gets weight ``1.0 / self.num_particles``.

        :raises NotImplementedError: if ``self.enum_discrete`` is set. This is
            raised lazily, when the first particle is requested.
        """
        for _ in range(self.num_particles):
            if self.enum_discrete:
                raise NotImplementedError("https://github.com/uber/pyro/issues/220")

            guide_tr = poutine.trace(
                guide, graph_type="dense"
            ).get_trace(*args, **kwargs)
            replayed_model = poutine.replay(model, guide_tr)
            model_tr = poutine.trace(
                replayed_model, graph_type="dense"
            ).get_trace(*args, **kwargs)

            check_model_guide_match(model_tr, guide_tr)
            pruned_guide = prune_subsample_sites(guide_tr)
            pruned_model = prune_subsample_sites(model_tr)

            yield 1.0 / self.num_particles, pruned_model, pruned_guide
示例#47
0
文件: util.py 项目: pyro-ppl/pyro
 def _potential_fn(self, params):
     """Return the potential energy (negative log joint) at ``params``.

     ``params`` maps site names to unconstrained values. Each value is mapped
     back with ``self.transforms[name].inv``, the model is conditioned on the
     constrained values and traced, and its log probability (from
     ``self.trace_prob_evaluator``) has the summed ``log_abs_det_jacobian`` of
     every transform subtracted before being negated.
     """
     constrained = {
         site: self.transforms[site].inv(unconstrained)
         for site, unconstrained in params.items()
     }
     conditioned_model = poutine.condition(self.model, constrained)
     model_tr = poutine.trace(conditioned_model).get_trace(
         *self.model_args, **self.model_kwargs
     )
     log_joint = self.trace_prob_evaluator.log_prob(model_tr)
     for site, transform in self.transforms.items():
         jacobian_term = transform.log_abs_det_jacobian(constrained[site], params[site])
         log_joint = log_joint - torch.sum(jacobian_term)
     return -log_joint
示例#48
0
 def _get_initial_trace():
     """Fit an ``AutoDelta`` guide for 100 Adam steps and return its trace.

     The guide only covers sites whose names start with neither ``"x"`` nor
     ``"y"``. It is trained with ``TraceEnum_ELBO(max_plate_nesting=1)`` on
     the enclosing ``data``, then traced on that same ``data``.
     """
     visible_model = poutine.block(
         model,
         expose_fn=lambda msg: not msg["name"].startswith(("x", "y")),
     )
     guide = AutoDelta(visible_model)
     svi = SVI(
         model,
         guide,
         optim.Adam({"lr": 0.01}),
         TraceEnum_ELBO(max_plate_nesting=1),
     )
     for _step in range(100):
         svi.step(data)
     return poutine.trace(guide).get_trace(data)
示例#49
0
    def _sample_from_joint(self, *args, **kwargs):
        """
        :returns: a sample from the joint distribution over unobserved and
            observed variables
        :rtype: pyro.poutine.trace_struct.Trace

        Traces ``self.model`` after wrapping it with
        ``pyro.poutine.uncondition``, so no observations constrain the run.
        All arguments are forwarded to the model unchanged.
        """
        model_without_obs = pyro.poutine.uncondition(self.model)
        joint_trace = poutine.trace(model_without_obs).get_trace(*args, **kwargs)
        return joint_trace
示例#50
0
    def laplace_approximation(self, *args, **kwargs):
        """
        Returns a :class:`AutoMultivariateNormal` instance whose posterior's `loc` and
        `scale_tril` are given by Laplace approximation.

        The guide (``self``) is traced, ``self.model`` is replayed against that
        trace, and the Hessian of ``log q - log p`` with respect to
        ``self.loc`` is computed. Its inverse is used as the covariance, whose
        Cholesky factor becomes ``scale_tril``; ``self.loc`` is reused as
        ``loc``. Arguments are forwarded to the guide, the model, and
        ``AutoMultivariateNormal._setup_prototype``.
        """
        import torch

        guide_trace = poutine.trace(self).get_trace(*args, **kwargs)
        model_trace = poutine.trace(
            poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
        loss = guide_trace.log_prob_sum() - model_trace.log_prob_sum()

        H = hessian(loss, self.loc)
        cov = H.inverse()
        loc = self.loc
        # Tensor.cholesky is deprecated; torch.linalg.cholesky returns the same
        # lower-triangular factor.
        scale_tril = torch.linalg.cholesky(cov)

        gaussian_guide = AutoMultivariateNormal(self.model)
        gaussian_guide._setup_prototype(*args, **kwargs)
        # Set loc, scale_tril parameters as computed above.
        gaussian_guide.loc = loc
        gaussian_guide.scale_tril = scale_tril
        return gaussian_guide
示例#51
0
 def test_stack_success(self):
     """Two nested ``condition`` handlers each observe their own site.

     Each conditioned site must be an observed ``sample`` site whose value is
     the exact tensor object that was supplied.
     """
     data1 = {"latent1": torch.randn(2)}
     data2 = {"latent2": torch.randn(2)}
     inner = poutine.condition(self.model, data=data1)
     outer = poutine.condition(inner, data=data2)
     tr = poutine.trace(outer).get_trace()
     for site, supplied in (("latent1", data1), ("latent2", data2)):
         node = tr.nodes[site]
         assert node["type"] == "sample" and node["is_observed"]
         assert node["value"] is supplied[site]
示例#52
0
 def step(self, *args, **kwargs):
     """Take one optimization step and return the loss via ``torch_item``.

     Parameters touched while computing the loss and gradients are captured
     with a ``param_only`` trace and taken in unconstrained form. Where a site
     carries a non-None ``"free"`` entry, that parameter's gradient is
     multiplied by it. ``self.optim`` is then applied and the gradients are
     zeroed.
     """
     with poutine.trace(param_only=True) as param_capture:
         loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

     unconstrained_params = []
     for node in param_capture.trace.nodes.values():
         p = node["value"].unconstrained()
         free = node.get('free')
         if free is not None:
             p.grad = free * p.grad
         unconstrained_params.append(p)

     self.optim(unconstrained_params)
     pyro.infer.util.zero_grads(unconstrained_params)
     return torch_item(loss)
示例#53
0
def test_replay(model, subsample_size):
    """Replaying a recorded trace reproduces the recorded output.

    For ``subsample_size < 20``, a fresh, non-replayed run is also expected
    to differ from the first one.
    """
    pyro.set_rng_seed(0)

    recorder = poutine.trace(model)
    first_output = recorder(subsample_size)

    replayed_model = poutine.replay(model, trace=recorder.trace)
    assert replayed_model(subsample_size) == first_output

    if subsample_size >= 20:
        return
    assert recorder(subsample_size) != first_output
示例#54
0
    def compute_l2_term(self, *args, **kwargs):
        """Return an L2 penalty on selected guide parameters.

        The guide is traced with the given arguments and every ``param`` site
        is read in unconstrained form. For ``theta_loc`` and ``alpha_loc``, the
        squared values are averaged over dim 1 and summed; the two results are
        added and scaled by ``self.l2_factor``.
        """
        guide_tr = poutine.trace(self.guide).get_trace(*args, **kwargs)
        params = {}
        for site_name, node in guide_tr.nodes.items():
            if node['type'] == 'param':
                params[site_name] = node['value'].unconstrained()

        # only account for theta and alpha for now
        theta_term = params['theta_loc'].square().mean(dim=1).sum()
        alpha_term = params['alpha_loc'].square().mean(dim=1).sum()
        return self.l2_factor * (theta_term + alpha_term)
示例#55
0
 def sample(
     self,
     num_steps: int,
     num_samples: Optional[int] = None,
     x: Optional[Tensor] = None,
 ) -> Tensor:
     """Generate ``num_steps`` predicted observations from the model.

     If ``x`` is given, the initial latent is ``self.guide`` applied to
     ``pack_sequences(list(x))`` and the sample count is ``x.size(0)``.
     Otherwise ``self.qz0`` is used and the count is ``num_samples``, or 1
     when that is falsy. The values of sites ``x_1`` .. ``x_<num_steps>``
     from one traced model run are stacked along dim 1.
     """
     if x is None:
         num_samples = num_samples or 1
         z = self.qz0
     else:
         num_samples = x.size(0)
         z = self.guide(*pack_sequences(list(x)))
     nodes = trace(self._model).get_trace(z, num_samples, num_steps).nodes
     step_values = [nodes[f"x_{step}"]["value"] for step in range(1, num_steps + 1)]
     return torch.stack(step_values, 1)
示例#56
0
 def test_graph_structure(self):
     """The dense trace graph has exactly the expected nodes and edges.

     Nodes named ``plate_*``, and any edge touching one, are ignored.
     """
     tracegraph = poutine.trace(self.model, graph_type="dense").get_trace()

     def is_plate(node_name):
         return node_name.startswith("plate_")

     # Ignore structure on plate_* nodes.
     actual_nodes = {n for n in tracegraph.nodes if not is_plate(n)}
     actual_edges = {
         (src, dst)
         for src, dst in tracegraph.edges
         if not (is_plate(src) or is_plate(dst))
     }
     assert actual_nodes == self.expected_nodes
     assert actual_edges == self.expected_edges
示例#57
0
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator

        Yields ``(weight, model_trace, guide_trace, log_r)`` tuples, where
        ``log_r`` is the model log-density minus the guide log-density.

        With ``self.enum_discrete`` set, each particle yields one tuple per
        trace from ``iter_discrete_traces("flat", guide, ...)``, weighted by
        that trace's ``scale / num_particles`` and using per-batch densities
        (``batch_log_pdf``). Otherwise each particle yields a single tuple
        with weight ``1 / num_particles`` and scalar densities (``log_pdf``).

        XXX support for automatically settings args/kwargs to volatile?
        """

        for i in range(self.num_particles):
            if self.enum_discrete:
                # This iterates over a bag of traces, for each particle.
                for scale, guide_trace in iter_discrete_traces(
                        "flat", guide, *args, **kwargs):
                    # Run the model against this particular guide trace.
                    model_trace = poutine.trace(
                        poutine.replay(model, guide_trace),
                        graph_type="flat").get_trace(*args, **kwargs)

                    check_model_guide_match(model_trace, guide_trace)
                    guide_trace = prune_subsample_sites(guide_trace)
                    model_trace = prune_subsample_sites(model_trace)
                    check_enum_discrete_can_run(model_trace, guide_trace)

                    log_r = model_trace.batch_log_pdf(
                    ) - guide_trace.batch_log_pdf()
                    weight = scale / self.num_particles
                    yield weight, model_trace, guide_trace, log_r
                # Skip the non-enumerated path below for this particle.
                continue

            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model,
                                                       guide_trace)).get_trace(
                                                           *args, **kwargs)

            check_model_guide_match(model_trace, guide_trace)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            log_r = model_trace.log_pdf() - guide_trace.log_pdf()
            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace, log_r
示例#58
0
def test_trace_smoke(Model, batch_shape, t_obs, obs_dim, cov_dim):
    """Posterior traces from the SVI and HMC forecasters match a manual trace.

    For every site of a guide-then-replayed-model trace under a
    ``particles`` plate (excluding ``_Subsample`` sites), each engine's
    trace must contain the site, with the same ``event_shape``.
    """
    model = Model()
    data = torch.randn(batch_shape + (t_obs, obs_dim))
    covariates = torch.randn(batch_shape + (t_obs, cov_dim))
    forecaster = Forecaster(model, data, covariates, num_steps=2, log_every=1)
    hmc_forecaster = HMCForecaster(
        model,
        data,
        covariates,
        max_tree_depth=1,
        num_warmup=1,
        num_samples=1,
        jit_compile=False,
    )

    # This is the desired syntax for recording posterior latent samples.
    num_samples = 5
    with poutine.trace() as svi:
        forecaster(data, covariates, num_samples)
    with poutine.trace() as hmc:
        hmc_forecaster(data, covariates, num_samples)

    # Check they match the equivalent poutine version.
    dim = -1 - forecaster.max_plate_nesting
    with torch.no_grad():
        with poutine.trace() as tr:
            with pyro.plate("particles", num_samples, dim=dim):
                forecaster.guide(data, covariates)
        with poutine.trace() as expected:
            with poutine.replay(trace=tr.trace):
                with pyro.plate("particles", num_samples, dim=dim):
                    model(data, covariates)
    for actual, engine in zip([svi, hmc], ["svi", "hmc"]):
        for name, site in expected.trace.nodes.items():
            expected_fn = site["fn"]
            if type(expected_fn).__name__ == "_Subsample":
                continue
            # Check membership before indexing: previously the lookup ran
            # first, so a missing site raised a bare KeyError and this
            # assertion (which names the engine) could never fire.
            assert name in actual.trace.nodes, engine
            actual_fn = actual.trace.nodes[name]["fn"]
            assert actual_fn.event_shape == expected_fn.event_shape, engine
示例#59
0
def test_moments(dist_type, centered, shape):
    """``LocScaleReparam`` preserves sample moments and their gradients.

    Draws 200000 particles of site ``x`` from the distribution selected by
    ``dist_type`` (``"Normal"``, ``"StudentT"``, otherwise
    ``AsymmetricLaplace``), once directly and once through ``poutine.reparam``
    with ``LocScaleReparam(centered)``. It compares the ``get_moments``
    probes and their gradients w.r.t. ``loc`` and ``scale``. When
    ``centered`` is not identically one, it also checks the reparameterizer's
    ``shape_params`` for the selected distribution.
    """
    loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
    scale = torch.empty(shape).uniform_(0.5, 1.5).requires_grad_()
    if isinstance(centered, torch.Tensor):
        centered = centered.expand(shape)

    def model():
        with pyro.plate_stack("plates", shape):
            with pyro.plate("particles", 200000):
                # Compare the ``dist_type`` parameter. The previous code
                # compared the string literal "dist_type", which never equals
                # "Normal"/"StudentT", so only AsymmetricLaplace was tested.
                if dist_type == "Normal":
                    pyro.sample("x", dist.Normal(loc, scale))
                elif dist_type == "StudentT":
                    pyro.sample("x", dist.StudentT(10.0, loc, scale))
                else:
                    pyro.sample("x", dist.AsymmetricLaplace(loc, scale, 1.5))

    value = poutine.trace(model).get_trace().nodes["x"]["value"]
    expected_probe = get_moments(value)

    reparam = LocScaleReparam(centered)
    reparam_model = poutine.reparam(model, {"x": reparam})
    value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"]
    actual_probe = get_moments(value)

    if not is_identically_one(centered):
        if dist_type == "Normal":
            assert reparam.shape_params == ()
        elif dist_type == "StudentT":
            assert reparam.shape_params == ("df", )
        else:
            assert reparam.shape_params == ("asymmetry", )

    assert_close(actual_probe, expected_probe, atol=0.1, rtol=0.05)

    for actual_m, expected_m in zip(actual_probe, expected_probe):
        expected_grads = grad(expected_m.sum(), [loc, scale],
                              retain_graph=True)
        actual_grads = grad(actual_m.sum(), [loc, scale], retain_graph=True)
        assert_close(actual_grads[0], expected_grads[0], atol=0.1, rtol=0.05)
        assert_close(actual_grads[1], expected_grads[1], atol=0.1, rtol=0.05)
示例#60
0
文件: eig.py 项目: youisbaby/pyro
def lfire_eig(model, design, observation_labels, target_labels,
              num_y_samples, num_theta_samples, num_steps, classifier, optim, return_history=False,
              final_design=None, final_num_samples=None):
    """Estimates the EIG using the method of Likelihood-Free Inference by Ratio Estimation (LFIRE) as in [1].
    LFIRE is run separately for several samples of :math:`\\theta`.

    [1] Kleinegesse, Steven, and Michael Gutmann. "Efficient Bayesian Experimental Design for Implicit Models."
    arXiv preprint arXiv:1810.09912 (2018).

    :param function model: A pyro model accepting `design` as only argument.
    :param torch.Tensor design: Tensor representation of design
    :param list observation_labels: A subset of the sample sites
        present in `model`. These sites are regarded as future observations
        and other sites are regarded as latent variables over which a
        posterior is to be inferred. A single string is treated as a
        one-element list.
    :param list target_labels: A subset of the sample sites over which the posterior
        entropy is to be measured. A single string is treated as a
        one-element list.
    :param int num_y_samples: Number of samples to take in :math:`y` for each :math:`\\theta`.
    :param int num_theta_samples: Number of initial samples in :math:`\\theta` to take. The likelihood ratio
        is estimated by LFIRE for each sample.
    :param int num_steps: Number of optimization steps.
    :param function classifier: a Pytorch or Pyro classifier used to distinguish between samples of :math:`y` under
        :math:`p(y|d)` and samples under :math:`p(y|\\theta,d)` for some :math:`\\theta`.
    :param pyro.optim.Optim optim: Optimizer to use.
    :param bool return_history: If `True`, also returns a tensor giving the loss function
        at each step of the optimization.
    :param torch.Tensor final_design: The final design tensor to evaluate at. If `None`, uses
        `design`.
    :param int final_num_samples: The number of samples to use at the final evaluation. If `None`,
        the fallback is decided by ``opt_eig_ape_loss`` (presumably `num_y_samples`).
    :return: EIG estimate, averaged over the :math:`\\theta` samples. If `return_history`
        is `True`, a tuple of the first output of ``opt_eig_ape_loss`` and the averaged
        second output.
    :rtype: torch.Tensor or tuple
    """
    if isinstance(observation_labels, str):
        observation_labels = [observation_labels]
    if isinstance(target_labels, str):
        target_labels = [target_labels]

    # Take N samples of the model
    expanded_design = lexpand(design, num_theta_samples)
    trace = poutine.trace(model).get_trace(expanded_design)

    # Fix theta at the sampled values; LFIRE then contrasts p(y|theta,d)
    # against p(y|d).
    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}
    cond_model = pyro.condition(model, data=theta_dict)

    loss = _lfire_loss(model, cond_model, classifier, observation_labels, target_labels)
    out = opt_eig_ape_loss(expanded_design, loss, num_y_samples, num_steps, optim, return_history,
                           final_design, final_num_samples)
    # Average over dim 0, assumed to index the theta samples added by lexpand
    # (TODO confirm).
    if return_history:
        return out[0], out[1].sum(0) / num_theta_samples
    else:
        return out.sum(0) / num_theta_samples