Example #1
def test_model_guide_mismatch(behavior, model_size, guide_size, model):
    model = poutine.trace(model)
    expected_ind = model(guide_size)
    if behavior == "ok":
        actual_ind = poutine.replay(model, trace=model.trace)(model_size)
        assert actual_ind == expected_ind
    else:
        with pytest.raises(ValueError):
            poutine.replay(model, trace=model.trace)(model_size)
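All of the examples on this page build on the same round trip: poutine.trace records the sample sites of one execution, and poutine.replay forces a later execution to reuse the recorded values. A minimal self-contained sketch of that pattern (toy code, not taken from any example here):

import pyro
import pyro.distributions as dist
from pyro import poutine


def toy_model():
    return pyro.sample("x", dist.Normal(0.0, 1.0))


# Record one execution, then replay it: the replayed trace reuses the
# recorded value at site "x" instead of drawing a fresh sample.
recorded = poutine.trace(toy_model).get_trace()
replayed = poutine.trace(poutine.replay(toy_model, trace=recorded)).get_trace()
assert replayed.nodes["x"]["value"] is recorded.nodes["x"]["value"]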
Example #2
 def test_replay_full_repeat(self):
     model_trace = poutine.trace(self.model).get_trace()
     ftr = poutine.trace(poutine.replay(self.model, trace=model_trace))
     tr11 = ftr.get_trace()
     tr12 = ftr.get_trace()
     tr2 = poutine.trace(poutine.replay(self.model, trace=model_trace)).get_trace()
     for name in self.full_sample_sites.keys():
         assert_equal(tr11.nodes[name]["value"], tr12.nodes[name]["value"])
         assert_equal(tr11.nodes[name]["value"], tr2.nodes[name]["value"])
         assert_equal(model_trace.nodes[name]["value"], tr11.nodes[name]["value"])
         assert_equal(model_trace.nodes[name]["value"], tr2.nodes[name]["value"])
Example #3
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator
        """

        for i in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)
            if is_validation_enabled():
                check_model_guide_match(model_trace, guide_trace)
                enumerated_sites = [name for name, site in guide_trace.nodes.items()
                                    if site["type"] == "sample" and site["infer"].get("enumerate")]
                if enumerated_sites:
                    warnings.warn('\n'.join([
                        'TraceGraph_ELBO found sample sites configured for enumeration: ' +
                        ', '.join(enumerated_sites),
                        'If you want to enumerate sites, you need to use TraceEnum_ELBO instead.']))

            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
Example #4
def test_compute_downstream_costs_irange_in_iarange(dim1, dim2):
    guide_trace = poutine.trace(nested_model_guide2,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = poutine.trace(poutine.replay(nested_model_guide2, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])

    expected_b1 = model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']
    expected_b1 += model_trace.nodes['obs1']['log_prob']
    assert_equal(expected_b1, dc['b1'])

    expected_c = model_trace.nodes['c']['log_prob'] - guide_trace.nodes['c']['log_prob']
    for i in range(dim2):
        expected_c += model_trace.nodes['b{}'.format(i)]['log_prob'] - \
            guide_trace.nodes['b{}'.format(i)]['log_prob']
        expected_c += model_trace.nodes['obs{}'.format(i)]['log_prob']
    assert_equal(expected_c, dc['c'])

    expected_a1 = model_trace.nodes['a1']['log_prob'] - guide_trace.nodes['a1']['log_prob']
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
Example #5
def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    guide_trace = poutine.trace(iarange_reuse_model_guide,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2)
    model_trace = poutine.trace(poutine.replay(iarange_reuse_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute

    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])

    expected_c1 = model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']
    expected_c1 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum()
    expected_c1 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
Example #6
def test_replay_enumerate_poutine(depth, first_available_dim):
    num_particles = 2
    y_dist = Categorical(torch.tensor([0.5, 0.25, 0.25]))

    def guide():
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})

    guide = poutine.enum(guide, first_available_dim=depth + first_available_dim)
    guide = poutine.trace(guide)
    guide_trace = guide.get_trace()

    def model():
        pyro.sample("x", Bernoulli(0.5))
        for i in range(depth):
            pyro.sample("a_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})
        pyro.sample("y", y_dist, infer={"enumerate": "parallel"})
        for i in range(depth):
            pyro.sample("b_{}".format(i), Bernoulli(0.5), infer={"enumerate": "parallel"})

    model = poutine.enum(model, first_available_dim=first_available_dim)
    model = poutine.replay(model, trace=guide_trace)
    model = poutine.trace(model)

    for i in range(num_particles):
        tr = model.get_trace()
        assert tr.nodes["y"]["value"] is guide_trace.nodes["y"]["value"]
        tr.compute_log_prob()
        log_prob = sum(site["log_prob"] for name, site in tr.iter_stochastic_nodes())
        actual_shape = log_prob.shape
        expected_shape = (2,) * depth + (3,) + (2,) * depth + (1,) * first_available_dim
        assert actual_shape == expected_shape, 'error on iteration {}'.format(i)
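The poutine.enum handler in this test is what turns each discrete site into a tensor of all its possible values. A rough, hedged sketch of that effect on a single site (toy code using the modern negative first_available_dim convention, whereas the test above uses the older positive one):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


def enum_toy():
    return pyro.sample("c", dist.Categorical(torch.ones(3) / 3.0),
                       infer={"enumerate": "parallel"})


# Under enum, the traced value at "c" carries all three categories at once
# instead of a single draw.
tr = poutine.trace(poutine.enum(enum_toy, first_available_dim=-1)).get_trace()
assert tr.nodes["c"]["value"].shape == (3,)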
Example #7
def test_decorator_interface_primitives():

    @poutine.trace
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "flat"

    @poutine.trace(graph_type="dense")
    def model():
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    tr = model.get_trace()
    assert isinstance(tr, poutine.Trace)
    assert tr.graph_type == "dense"

    tr2 = poutine.trace(poutine.replay(model, trace=tr)).get_trace()

    assert_equal(tr2.nodes["a"]["value"], tr.nodes["a"]["value"])
Example #8
 def _traces(self, *args, **kwargs):
     if not self.posterior.exec_traces:
         self.posterior.run(*args, **kwargs)
     for _ in range(self.num_samples):
         model_trace = self.posterior()
         replayed_trace = poutine.trace(poutine.replay(self.model, model_trace)).get_trace(*args, **kwargs)
         yield (replayed_trace, 0.)
Example #9
 def _get_trace(self, z):
     z_trace = self._prototype_trace
     for name, value in z.items():
         z_trace.nodes[name]["value"] = value
     trace_poutine = poutine.trace(poutine.replay(self.model, trace=z_trace))
     trace_poutine(*self._args, **self._kwargs)
     return trace_poutine.trace
Example #10
 def _traces(self, *args, **kwargs):
     """
     Generator of weighted samples from the proposal distribution.
     """
     for i in range(self.num_samples):
         guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
         model_trace = poutine.trace(
             poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
         log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
         yield (model_trace, log_weight)
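The log_weight yielded for each trace is a standard importance weight. A hedged sketch of how a caller might combine these pairs into a self-normalized estimate of a scalar sample site's posterior mean (importance_estimate and site_name are illustrative helpers, not part of the class above):

import torch


def importance_estimate(weighted_traces, site_name):
    # Self-normalized importance sampling: normalize the weights in log
    # space, then average the site's (scalar) values under those weights.
    values, log_weights = [], []
    for model_trace, log_weight in weighted_traces:
        values.append(model_trace.nodes[site_name]["value"])
        log_weights.append(log_weight)
    log_weights = torch.stack(log_weights)
    probs = (log_weights - log_weights.logsumexp(dim=0)).exp()
    return (probs * torch.stack(values)).sum(dim=0)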
Example #11
            def compiled(unconstrained_params, *args):
                self = weakself()
                constrained_params = {}
                for name, unconstrained_param in zip(self._param_names, unconstrained_params):
                    constrained_param = pyro.param(name)  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param

                return poutine.replay(
                    self.fn, params=constrained_params)(*args, **kwargs)
Example #12
    def ite(self, x, num_samples=None, batch_size=None):
        r"""
        Computes Individual Treatment Effect for a batch of data ``x``.

        .. math::

            ITE(x) = \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=1) \bigr]
                   - \mathbb E\bigl[ \mathbf y \mid \mathbf X=x, do(\mathbf t=0) \bigr]

        This has complexity ``O(len(x) * num_samples ** 2)``.

        :param ~torch.Tensor x: A batch of data.
        :param int num_samples: The number of Monte Carlo samples.
            Defaults to ``self.num_samples`` which defaults to ``100``.
        :param int batch_size: Batch size. Defaults to ``len(x)``.
        :return: A ``len(x)``-sized tensor of estimated effects.
        :rtype: ~torch.Tensor
        """
        if num_samples is None:
            num_samples = self.num_samples
        if not torch._C._get_tracing_state():
            assert x.dim() == 2 and x.size(-1) == self.feature_dim

        dataloader = [x] if batch_size is None else DataLoader(
            x, batch_size=batch_size)
        logger.info("Evaluating {} minibatches".format(len(dataloader)))
        result = []
        for x in dataloader:
            x = self.whiten(x)
            with pyro.plate("num_particles", num_samples, dim=-2):
                with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                    self.guide(x)
                with poutine.do(data=dict(t=torch.zeros(()))):
                    y0 = poutine.replay(self.model.y_mean, tr.trace)(x)
                with poutine.do(data=dict(t=torch.ones(()))):
                    y1 = poutine.replay(self.model.y_mean, tr.trace)(x)
            ite = (y1 - y0).mean(0)
            if not torch._C._get_tracing_state():
                logger.debug("batch ate = {:0.6g}".format(ite.mean()))
            result.append(ite)
        return torch.cat(result)
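The do/replay combination above is the heart of the estimator: latents are replayed from the guide trace while the treatment site is forced to 0 or 1. A stripped-down sketch of the same pattern on a toy model (my own code, not the CEVAE-style model this method belongs to):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


def toy_model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    t = pyro.sample("t", dist.Bernoulli(0.5))
    return pyro.sample("y", dist.Normal(z + 2.0 * t, 1.0))


# Trace only the latent z; hide y and t so they are re-executed on replay.
with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
    toy_model()
# Replay z while intervening on t: both outcomes share the same latent.
with poutine.do(data={"t": torch.zeros(())}):
    y0 = poutine.replay(toy_model, trace=tr.trace)()
with poutine.do(data={"t": torch.ones(())}):
    y1 = poutine.replay(toy_model, trace=tr.trace)()
# y1 - y0 isolates the effect of t, up to fresh observation noise in y.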
Example #13
def relbo(model, guide, *args, **kwargs):
    approximation = kwargs.pop('approximation')
    traced_guide = trace(guide)
    elbo = pyro.infer.Trace_ELBO(num_particles=NUM_PARTICLES)
    loss_fn = elbo.differentiable_loss(model, traced_guide, *args, **kwargs)
    guide_trace = traced_guide.trace
    replayed_approximation = trace(replay(block(approximation,
                                                expose=['beta_latent', 'z']),
                                          guide_trace))
    approximation_trace = replayed_approximation.get_trace(*args, **kwargs)
    relbo = -loss_fn - approximation_trace.log_prob_sum()
    return -relbo
Example #14
 def test_replay_partial(self):
     guide_trace = poutine.trace(self.guide).get_trace()
     model_trace = poutine.trace(poutine.replay(self.model,
                                                guide_trace,
                                                sites=self.partial_sample_sites)).get_trace()
     for name in self.full_sample_sites.keys():
         if name in self.partial_sample_sites:
             assert_equal(model_trace.nodes[name]["value"],
                          guide_trace.nodes[name]["value"])
         else:
             assert not eq(model_trace.nodes[name]["value"],
                           guide_trace.nodes[name]["value"])
Example #16
def test_replay(model, subsample_size):
    pyro.set_rng_seed(0)

    traced_model = poutine.trace(model)
    original = traced_model(subsample_size)

    replayed = poutine.replay(model, trace=traced_model.trace)(subsample_size)
    assert replayed == original

    if subsample_size < 20:
        different = traced_model(subsample_size)
        assert different != original
Example #17
 def _traces(self, *args, **kwargs):
     """
     Generator of weighted samples from the proposal distribution.
     """
     for i in range(self.num_samples):
         guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
         model_trace = poutine.trace(
             poutine.replay(self.model,
                            trace=guide_trace)).get_trace(*args, **kwargs)
         log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
         yield (model_trace, log_weight)
Example #18
 def _traces(self, *args, **kwargs):
     if not self.posterior.exec_traces:
         self.posterior.run(*args, **kwargs)
     data_trace = poutine.trace(self.model).get_trace(*args, **kwargs)
     for _ in range(self.num_samples):
         model_trace = self.posterior().copy()
         self._remove_dropped_nodes(model_trace)
         self._adjust_to_data(model_trace, data_trace)
         resampled_trace = poutine.trace(
             poutine.replay(self.model,
                            model_trace)).get_trace(*args, **kwargs)
         yield (resampled_trace, 0., 0)
Example #20
    def _get_one_posterior_sample(
        self,
        args,
        kwargs,
        return_sites: Optional[list] = None,
        return_observed: bool = False,
    ):
        """
        Get one sample from posterior distribution.

        Parameters
        ----------
        args
            arguments to model and guide
        kwargs
            arguments to model and guide
        return_sites
            List of variables for which to generate posterior samples, defaults to all variables.
        return_observed
            Record samples of observed variables.

        Returns
        -------
        Dictionary with a sample for each variable
        """
        if isinstance(self.module.guide, poutine.messenger.Messenger):
            # This already includes trace-replay behavior.
            sample = self.module.guide(*args, **kwargs)
        else:
            guide_trace = poutine.trace(self.module.guide).get_trace(
                *args, **kwargs)
            model_trace = poutine.trace(
                poutine.replay(self.module.model,
                               guide_trace)).get_trace(*args, **kwargs)
            sample = {
                name: site["value"]
                for name, site in model_trace.nodes.items() if
                ((site["type"] == "sample")  # sample statement
                 and ((return_sites is None) or
                      (name in return_sites))  # selected in return_sites list
                 and (((not site.get("is_observed", True)) or return_observed
                       )  # don't save observed unless requested
                      or (site.get("infer", False).get("_deterministic", False)
                          ))  # unless it is deterministic
                 and not isinstance(site.get(
                     "fn", None), poutine.subsample_messenger._Subsample
                                    )  # don't save plates
                 )
            }

        sample = {name: site.cpu().numpy() for name, site in sample.items()}

        return sample
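A stripped-down sketch of the else-branch above: replay the model against a guide trace, then keep only the values of non-observed sample sites (toy model and guide; the real module's extra filters for deterministic sites and subsample plates are omitted):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


def tiny_model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(z, 1.0), obs=torch.tensor(0.0))


def tiny_guide():
    pyro.sample("z", dist.Normal(0.0, 1.0))


guide_trace = poutine.trace(tiny_guide).get_trace()
model_trace = poutine.trace(
    poutine.replay(tiny_model, guide_trace)).get_trace()
# Keep latent sample sites only, mirroring the filter above in miniature.
sample = {name: site["value"]
          for name, site in model_trace.nodes.items()
          if site["type"] == "sample" and not site["is_observed"]}
assert set(sample) == {"z"}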
Example #21
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator

        XXX support for automatically settings args/kwargs to volatile?
        """

        for i in range(self.num_particles):
            if self.enum_discrete:
                # This iterates over a bag of traces, for each particle.
                for scale, guide_trace in iter_discrete_traces(
                        "flat", guide, *args, **kwargs):
                    model_trace = poutine.trace(
                        poutine.replay(model, guide_trace),
                        graph_type="flat").get_trace(*args, **kwargs)

                    check_model_guide_match(model_trace, guide_trace)
                    guide_trace = prune_subsample_sites(guide_trace)
                    model_trace = prune_subsample_sites(model_trace)
                    check_enum_discrete_can_run(model_trace, guide_trace)

                    log_r = model_trace.batch_log_pdf() - guide_trace.batch_log_pdf()
                    weight = scale / self.num_particles
                    yield weight, model_trace, guide_trace, log_r
                continue

            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model,
                                                       guide_trace)).get_trace(
                                                           *args, **kwargs)

            check_model_guide_match(model_trace, guide_trace)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            log_r = model_trace.log_pdf() - guide_trace.log_pdf()
            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace, log_r
Example #22
    def log_prob(self, data):
        self.reset_observations()
        self.observe(data=data)
        self.infer(3_000, train=False)

        conditioned_model, conditioned_guide = self.model_conditioned(anwers=self._observations, traits=self._observed_traits)

        guide_trace = poutine.trace(conditioned_guide).get_trace()
        model_trace = poutine.trace(
            poutine.replay(conditioned_model, trace=guide_trace)
        )

        return (model_trace.get_trace().log_prob_sum(lambda x, y: x == 'anwser')).detach().numpy()
Example #23
    def log_prob(self, data):
        self.reset_observations()
        self.observe(data=data)
        self.infer(3_000)

        model_conditioned = self.model_conditioned()

        guide_trace = poutine.trace(self.traits_guide).get_trace()
        model_trace = poutine.trace(
            poutine.replay(model_conditioned, trace=guide_trace))

        return (model_trace.get_trace().log_prob_sum(
            lambda x, y: x == 'anwser')).detach().numpy()
Example #24
def test_get_mask_optimization():
    def model():
        x = pyro.sample("x", dist.Normal(0, 1))
        pyro.sample("y", dist.Normal(x, 1), obs=torch.tensor(0.0))
        called.add("model-always")
        if poutine.get_mask() is not False:
            called.add("model-sometimes")
            pyro.factor("f", x + 1)

    def guide():
        x = pyro.sample("x", dist.Normal(0, 1))
        called.add("guide-always")
        if poutine.get_mask() is not False:
            called.add("guide-sometimes")
            pyro.factor("g", 2 - x)

    called = set()
    trace = poutine.trace(guide).get_trace()
    poutine.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" in called
    assert "guide-sometimes" in called

    called = set()
    with poutine.mask(mask=False):
        trace = poutine.trace(guide).get_trace()
        poutine.replay(model, trace)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called

    called = set()
    Predictive(model, guide=guide, num_samples=2, parallel=True)()
    assert "model-always" in called
    assert "guide-always" in called
    assert "model-sometimes" not in called
    assert "guide-sometimes" not in called
def test_compute_downstream_costs_plate_in_iplate(dim1):
    guide_trace = poutine.trace(
        nested_model_guide, graph_type="dense").get_trace(include_obs=False,
                                                          dim1=dim1)
    model_trace = poutine.trace(poutine.replay(nested_model_guide,
                                               trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True,
                                                              dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_c1 = (model_trace.nodes['c1']['log_prob'] -
                   guide_trace.nodes['c1']['log_prob'])
    expected_c1 += model_trace.nodes['obs1']['log_prob']

    expected_b1 = (model_trace.nodes['b1']['log_prob'] -
                   guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['c1']['log_prob'] -
                    guide_trace.nodes['c1']['log_prob']).sum()
    expected_b1 += model_trace.nodes['obs1']['log_prob'].sum()

    expected_c0 = (model_trace.nodes['c0']['log_prob'] -
                   guide_trace.nodes['c0']['log_prob'])
    expected_c0 += model_trace.nodes['obs0']['log_prob']

    expected_b0 = (model_trace.nodes['b0']['log_prob'] -
                   guide_trace.nodes['b0']['log_prob'])
    expected_b0 += (model_trace.nodes['c0']['log_prob'] -
                    guide_trace.nodes['c0']['log_prob']).sum()
    expected_b0 += model_trace.nodes['obs0']['log_prob'].sum()

    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c0, dc['c0'], prec=1.0e-6)
    assert_equal(expected_b0, dc['b0'], prec=1.0e-6)

    for k in dc:
        assert (guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
Example #26
    def _get_matched_trace(model_trace, guide, args, kwargs):
        kwargs["observations"] = {}
        for node in model_trace.stochastic_nodes + model_trace.observation_nodes:
            if "was_observed" in model_trace.nodes[node]["infer"]:
                model_trace.nodes[node]["is_observed"] = True
                kwargs["observations"][node] = model_trace.nodes[node]["value"]

        guide_trace = poutine.trace(poutine.replay(guide,
                                                   model_trace)).get_trace(
                                                       *args, **kwargs)
        check_model_guide_match(model_trace, guide_trace)
        guide_trace = prune_subsample_sites(guide_trace)

        return guide_trace
Example #27
def test_compute_downstream_costs_duplicates(dim):
    guide_trace = poutine.trace(diamond_guide,
                                graph_type="dense").get_trace(dim=dim)
    model_trace = poutine.trace(poutine.replay(diamond_model,
                                               trace=guide_trace),
                                graph_type="dense").get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_a1 = (model_trace.nodes['a1']['log_prob'] -
                   guide_trace.nodes['a1']['log_prob'])
    for d in range(dim):
        expected_a1 += model_trace.nodes['b{}'.format(d)]['log_prob']
        expected_a1 -= guide_trace.nodes['b{}'.format(d)]['log_prob']
    expected_a1 += (model_trace.nodes['c1']['log_prob'] -
                    guide_trace.nodes['c1']['log_prob'])
    expected_a1 += model_trace.nodes['obs']['log_prob']

    expected_b1 = -guide_trace.nodes['b1']['log_prob']
    for d in range(dim):
        expected_b1 += model_trace.nodes['b{}'.format(d)]['log_prob']
    expected_b1 += (model_trace.nodes['c1']['log_prob'] -
                    guide_trace.nodes['c1']['log_prob'])
    expected_b1 += model_trace.nodes['obs']['log_prob']

    expected_c1 = (model_trace.nodes['c1']['log_prob'] -
                   guide_trace.nodes['c1']['log_prob'])
    for d in range(dim):
        expected_c1 += model_trace.nodes['b{}'.format(d)]['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)

    for k in dc:
        assert (guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
Example #28
            def compiled(unconstrained_params, *args):
                self = weakself()
                constrained_params = {}
                for name, unconstrained_param in zip(self._param_names,
                                                     unconstrained_params):
                    constrained_param = pyro.param(name)  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    constrained_params[name] = constrained_param

                return poutine.replay(self.fn,
                                      params=constrained_params)(*args,
                                                                 **kwargs)
Example #29
    def sample(self, params):
        """
        Samples parameters from the posterior distribution, given existing parameters.

        :param dict params: Current parameter values.
        :return: New parameters from the posterior distribution.
        """
        old_model_trace = poutine.trace(self.model)(self.args, self.kwargs)
        traces = []
        t = 0
        i = 0
        while t < self.burn + self.lag * self.samples:
            i += 1
            # q(z' | z)
            new_guide_trace = poutine.block(poutine.trace(self.model))(
                old_model_trace, self.args, self.kwargs)
            # p(x, z')
            new_model_trace = poutine.trace(
                poutine.replay(self.model, new_guide_trace))(self.args,
                                                             self.kwargs)
            # q(z | z')
            old_guide_trace = poutine.block(
                poutine.trace(poutine.replay(self.guide, old_model_trace)))(
                    new_model_trace, self.args, self.kwargs)
            # p(x, z') q(z' | z) / p(x, z) q(z | z')
            logr = new_model_trace.log_pdf() + new_guide_trace.log_pdf() - \
                old_model_trace.log_pdf() - old_guide_trace.log_pdf()
            rnd = pyro.sample("mh_step_{}".format(i),
                              Uniform(torch.zeros(1), torch.ones(1)))

            if torch.log(rnd).data[0] < logr.data[0]:
                # accept
                t += 1
                old_model_trace = new_model_trace
                if t <= self.burn or (t > self.burn and t % self.lag == 0):
                    yield (new_model_trace, new_model_trace.log_pdf())
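The comparison of torch.log(rnd) against logr near the end is the standard Metropolis-Hastings accept/reject step. In modern PyTorch (the example is written against the old Variable-era API) it reduces to a one-liner:

import torch


def mh_accept(log_ratio):
    # Accept with probability min(1, exp(log_ratio)).
    return torch.log(torch.rand(())) < log_ratio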
Example #30
def relbo(model, guide, *args, **kwargs):

    approximation = kwargs.pop('approximation', None)
    # Run the guide with the arguments passed to SVI.step() and trace the execution,
    # i.e. record all the calls to Pyro primitives like sample() and param().
    guide_trace = trace(guide).get_trace(*args, **kwargs)
    # Now run the model with the same arguments and trace the execution. Because
    # model is being run with replay, whenever we encounter a sample site in the
    # model, instead of sampling from the corresponding distribution in the model,
    # we instead reuse the corresponding sample from the guide. In probabilistic
    # terms, this means our loss is constructed as an expectation w.r.t. the joint
    # distribution defined by the guide.
    model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs)
    approximation_trace = trace(
        replay(block(approximation, expose=["obs"]),
               guide_trace)).get_trace(*args, **kwargs)
    # We will accumulate the various terms of the ELBO in `elbo`.
    elbo = 0.
    # Loop over all the sample sites in the model and add the corresponding
    # log p(z) term to the ELBO. Note that this will also include any observed
    # data, i.e. sample sites with the keyword `obs=...`.
    elbo = elbo + model_trace.log_prob_sum()
    # Loop over all the sample sites in the guide and add the corresponding
    # -log q(z) term to the ELBO.
    elbo = elbo - guide_trace.log_prob_sum()
    elbo = elbo - approximation_trace.log_prob_sum()

    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    if elbo < 10e-8 and PRINT_TRACES:
        print('Guide trace')
        print(guide_trace.log_prob_sum())
        print('Model trace')
        print(model_trace.log_prob_sum())
        print('Approximation trace')
        print(approximation_trace.log_prob_sum())
    return -elbo
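A hedged sketch of how a residual loss like relbo is usually wired into SVI; model, guide, and approximation are assumed to be defined as in the boosting setup above, and the optimizer settings are illustrative:

import pyro
from pyro.optim import Adam

# SVI calls loss(model, guide, *args, **kwargs), so the `approximation`
# keyword passed to step() reaches relbo through kwargs and is popped there.
svi = pyro.infer.SVI(model, guide, Adam({"lr": 0.01}), loss=relbo)
loss = svi.step(approximation=approximation)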
Example #31
def test_scores(auto_class):
    def model():
        pyro.sample("z", dist.Normal(0.0, 1.0))

    guide = auto_class(model)
    guide_trace = poutine.trace(guide).get_trace()
    model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace()

    guide_trace.compute_log_prob()
    model_trace.compute_log_prob()

    assert '_auto_latent' not in model_trace.nodes
    assert model_trace.nodes['z']['log_prob_sum'].item() != 0.0
    assert guide_trace.nodes['_auto_latent']['log_prob_sum'].item() != 0.0
    assert guide_trace.nodes['z']['log_prob_sum'].item() == 0.0
Example #32
    def step(self, *args, **kwargs):
        """
        Take a filtering step using sequential importance resampling updating the
        particle weights and values while resampling if desired.
        Any args or kwargs are passed to the model and guide
        """
        with poutine.block(), self.particle_plate:
            guide_trace = poutine.trace(self.guide.step).get_trace(
                *args, **kwargs)
            model = poutine.replay(self.model.step, guide_trace)
            model_trace = poutine.trace(model).get_trace(*args, **kwargs)

        self._update_weights(model_trace, guide_trace)
        self._values.update(_extract_samples(model_trace))
        self._maybe_importance_resample()
Example #34
    def loss_fn(model, guide, *args, **kwargs):
        guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(
            model, trace=guide_trace)).get_trace(*args, **kwargs)

        model_trace.compute_log_prob(lambda site, node: site in sites_to_map)

        loss = 0.0
        for site, node in model_trace.nodes.items():
            if site in sites_to_map:
                loss -= node['log_prob_sum'] * prior_weight
            if site == obs:
                loss -= node['fn'].log_prob(node['value']).sum()

        return loss
Example #35
    def pol_att(self, x, y, t, e):

        num_samples = self.num_samples
        if not torch._C._get_tracing_state():
            assert x.dim() == 2 and x.size(-1) == self.feature_dim

        dataloader = [x]
        print("Evaluating {} minibatches".format(len(dataloader)))
        result_pol = []
        result_eatt = []
        for x in dataloader:
            # x = self.whiten(x)
            with pyro.plate("num_particles", num_samples, dim=-2):
                with poutine.trace() as tr, poutine.block(hide=["y", "t"]):
                    self.guide(x)
                with poutine.do(data=dict(t=torch.zeros(()))):
                    y0 = poutine.replay(self.model.y_mean, tr.trace)(x)
                with poutine.do(data=dict(t=torch.ones(()))):
                    y1 = poutine.replay(self.model.y_mean, tr.trace)(x)

            ite = (y1 - y0).mean(0)
            ite[t > 0] = -ite[t > 0]
            eatt = torch.abs(torch.mean(ite[(t + e) > 1]))
            pols = []
            for s in range(num_samples):
                pols.append(policy_val(ypred1=y1[s], ypred0=y0[s], y=y, t=t))

            pol = torch.stack(pols).mean(0)

            if not torch._C._get_tracing_state():
                print("batch eATT = {:0.6g}".format(eatt))
                print("batch RPOL = {:0.6g}".format(pol))
            result_pol.append(pol)
            result_eatt.append(eatt)

        return torch.stack(result_pol), torch.stack(result_eatt)
Example #36
    def _get_traces(self, model, guide, args, kwargs):
        if self.max_plate_nesting == float("inf"):
            with validation_enabled(False):  # Avoid calling .log_prob() when undefined.
                # TODO factor this out as a stand-alone helper.
                ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
        vectorize = pyro.plate("num_particles_vectorized", self.num_particles,
                               dim=-self.max_plate_nesting)

        # Trace the guide as in ELBO.
        with poutine.trace() as tr, vectorize:
            guide(*args, **kwargs)
        guide_trace = tr.trace

        # Trace the model, drawing posterior predictive samples.
        with poutine.trace() as tr, poutine.uncondition():
            with poutine.replay(trace=guide_trace), vectorize:
                model(*args, **kwargs)
        model_trace = tr.trace
        for site in model_trace.nodes.values():
            if site["type"] == "sample" and site["infer"].get("was_observed", False):
                site["is_observed"] = True
        if is_validation_enabled():
            check_model_guide_match(model_trace, guide_trace, self.max_plate_nesting)

        guide_trace = prune_subsample_sites(guide_trace)
        model_trace = prune_subsample_sites(model_trace)
        if is_validation_enabled():
            for site in guide_trace.nodes.values():
                if site["type"] == "sample":
                    warn_if_nan(site["value"], site["name"])
                    if not getattr(site["fn"], "has_rsample", False):
                        raise ValueError("EnergyDistance requires fully reparametrized guides")
            for site in model_trace.nodes.values():
                if site["type"] == "sample":
                    if site["is_observed"]:
                        warn_if_nan(site["value"], site["name"])
                        if not getattr(site["fn"], "has_rsample", False):
                            raise ValueError("EnergyDistance requires reparametrized likelihoods")

        if self.prior_scale > 0:
            model_trace.compute_log_prob(site_filter=lambda name, site: not site["is_observed"])
            if is_validation_enabled():
                for site in model_trace.nodes.values():
                    if site["type"] == "sample":
                        if not site["is_observed"]:
                            check_site_shape(site, self.max_plate_nesting)

        return guide_trace, model_trace
Example #37
    def plot_recons(dataloader, mode):
        epoch = train_engine.state.epoch
        for batch in dataloader:
            x = batch[0]
            x = x.cuda()
            break
        x = x[:64]
        tb.add_image(f"{mode}/originals", torchvision.utils.make_grid(x),
                     epoch)
        bwd_trace = poutine.trace(backward_model).get_trace(x, **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x, **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image(f"{mode}/recons", torchvision.utils.make_grid(recon),
                     epoch)

        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(
            f"{mode}/canonical_recon",
            torchvision.utils.make_grid(canonical_recon),
            epoch,
        )

        # sample from the prior

        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(
            x, **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image(f"{mode}/prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)

        tb.add_image(
            f"{mode}/canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
        tb.add_image(
            f"{mode}/input_view",
            torchvision.utils.make_grid(
                bwd_trace.nodes["attention_input"]["value"]),
            epoch,
        )
Example #38
    def forward(self, *args, **kwargs):
        total_samples = {}
        for i in range(self.num_samples):
            if i % 50 == 0:
                print("done with {}".format(i))
            guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(self.model,
                                                       guide_trace)).get_trace(
                                                           *args, **kwargs)
            for site in prune_subsample_sites(model_trace).stochastic_nodes:
                if site not in total_samples:
                    total_samples[site] = []
                total_samples[site].append(model_trace.nodes[site]["value"])
        for key in total_samples.keys():
            total_samples[key] = torch.stack(total_samples[key])

        return total_samples
Example #39
    def multiple_predict(self, img):
        img_array = np.asarray(
            [np.copy(img) for _ in range(self.repeat_count)])
        img_array = torch.from_numpy(img_array)

        trace = poutine.trace(self.air.guide).get_trace(img_array, None)
        z, recons = poutine.replay(self.air.prior,
                                   trace=trace)(img_array.size(0))
        z_wheres = tensor_to_objs(latents_to_tensor(z))
        bboxes_frame = []
        for counter, z in enumerate(z_wheres):
            bboxes = ModelPipeline.get_bounding_boxes_for_image(img, z)
            # img2 = draw_bboxes_img(img, bboxes)
            bboxes_frame.append(bboxes)
        #         plt.imshow(img2)
        #         plt.show()
        return bboxes_frame
Example #40
    def init(self, *args, **kwargs):
        """
        Perform any initialization for sequential importance resampling.
        Any args or kwargs are passed to the model and guide
        """
        self.particle_plate = pyro.plate("particles",
                                         self.num_particles,
                                         dim=-1 - self.max_plate_nesting)
        with poutine.block(), self.particle_plate:
            guide_trace = poutine.trace(self.guide.init).get_trace(
                *args, **kwargs)
            model = poutine.replay(self.model.init, guide_trace)
            model_trace = poutine.trace(model).get_trace(*args, **kwargs)

        self._update_weights(model_trace, guide_trace)
        self._values.update(_extract_samples(model_trace))
        self._maybe_importance_resample()
Example #41
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator
        """
        # enable parallel enumeration
        guide = poutine.enum(guide,
                             first_available_dim=self.max_iarange_nesting)

        for i in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args,
                                                    **kwargs):
                model_trace = poutine.trace(poutine.replay(model,
                                                           trace=guide_trace),
                                            graph_type="flat").get_trace(
                                                *args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace,
                                            self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                            if site["infer"].get("enumerate"):
                                any_enumerated = True
                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn(
                            'TraceEnum_ELBO found no sample sites configured for enumeration. '
                            'If you want to enumerate sites, you need to @config_enumerate or set '
                            'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                            'If you do not want to enumerate, consider using Trace_ELBO instead.'
                        )

                yield model_trace, guide_trace
Example #42
def test_compute_downstream_costs_duplicates(dim):
    guide_trace = poutine.trace(diamond_guide,
                                graph_type="dense").get_trace(dim=dim)
    model_trace = poutine.trace(poutine.replay(diamond_model, trace=guide_trace),
                                graph_type="dense").get_trace(dim=dim)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
                                                                     non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_a1 = (model_trace.nodes['a1']['log_prob'] - guide_trace.nodes['a1']['log_prob'])
    for d in range(dim):
        expected_a1 += model_trace.nodes['b{}'.format(d)]['log_prob']
        expected_a1 -= guide_trace.nodes['b{}'.format(d)]['log_prob']
    expected_a1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])
    expected_a1 += model_trace.nodes['obs']['log_prob']

    expected_b1 = - guide_trace.nodes['b1']['log_prob']
    for d in range(dim):
        expected_b1 += model_trace.nodes['b{}'.format(d)]['log_prob']
    expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])
    expected_b1 += model_trace.nodes['obs']['log_prob']

    expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])
    for d in range(dim):
        expected_c1 += model_trace.nodes['b{}'.format(d)]['log_prob']
    expected_c1 += model_trace.nodes['obs']['log_prob']

    assert_equal(expected_a1, dc['a1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)

    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
Example #43
    def plot_images(engine):
        epoch = train_engine.state.epoch
        for x, y in test_dl:
            x = x.cuda()
            y = y.cuda()
            break
        x = x[:32]
        tb.add_image("originals", torchvision.utils.make_grid(x), epoch)
        bwd_trace = poutine.trace(backward_model).get_trace(x,
                                                            y,
                                                            N=x.shape[0],
                                                            **svi_args)
        fwd_trace = poutine.trace(
            poutine.replay(forward_model,
                           trace=bwd_trace)).get_trace(x,
                                                       y,
                                                       N=x.shape[0],
                                                       **svi_args)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
        tb.add_image("recons", torchvision.utils.make_grid(recon), epoch)

        canonical_recon = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image("canonical_recon",
                     torchvision.utils.make_grid(canonical_recon), epoch)

        # sample from the prior

        prior_sample_args = {}
        prior_sample_args.update(svi_args)
        prior_sample_args["cond"] = False
        prior_sample_args["cond_label"] = False
        fwd_trace = poutine.trace(forward_model).get_trace(x,
                                                           y,
                                                           N=x.shape[0],
                                                           **prior_sample_args)
        prior_sample = fwd_trace.nodes["pixels"]["fn"].mean
        prior_canonical_sample = fwd_trace.nodes["canonical_view"]["value"]
        tb.add_image("prior_samples",
                     torchvision.utils.make_grid(prior_sample), epoch)

        tb.add_image(
            "canonical_prior_samples",
            torchvision.utils.make_grid(prior_canonical_sample),
            epoch,
        )
Example #44
def test_compute_downstream_costs_iplate_in_plate(dim1, dim2):
    guide_trace = poutine.trace(
        nested_model_guide2, graph_type="dense").get_trace(include_obs=False,
                                                           dim1=dim1,
                                                           dim2=dim2)
    model_trace = poutine.trace(poutine.replay(nested_model_guide2,
                                               trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True,
                                                              dim1=dim1,
                                                              dim2=dim2)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    for k in dc:
        assert (guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])

    expected_b1 = model_trace.nodes['b1']['log_prob'] - guide_trace.nodes[
        'b1']['log_prob']
    expected_b1 += model_trace.nodes['obs1']['log_prob']
    assert_equal(expected_b1, dc['b1'])

    expected_c = model_trace.nodes['c']['log_prob'] - guide_trace.nodes['c'][
        'log_prob']
    for i in range(dim2):
        expected_c += model_trace.nodes['b{}'.format(i)]['log_prob'] - \
            guide_trace.nodes['b{}'.format(i)]['log_prob']
        expected_c += model_trace.nodes['obs{}'.format(i)]['log_prob']
    assert_equal(expected_c, dc['c'])

    expected_a1 = model_trace.nodes['a1']['log_prob'] - guide_trace.nodes[
        'a1']['log_prob']
    expected_a1 += expected_c.sum()
    assert_equal(expected_a1, dc['a1'])
Example #45
    def classifier(self, num_samples=1000, temperature=1):

        guide_trace = poutine.trace(self.guide).get_trace()
        trained_model = poutine.replay(self.model,
                                       trace=guide_trace)  # replay the globals

        classes = []
        for n in range(num_samples):
            # avoid conflict with data plate
            inferred_model = infer_discrete(trained_model,
                                            temperature=temperature,
                                            first_available_dim=-1)
            trace = poutine.trace(inferred_model).get_trace()
            classes.append(trace.nodes["class"]["value"])

        self.classes = torch.stack(classes)

        return self.classes
Example #46
def test_compute_downstream_costs_iarange_in_irange(dim1):
    guide_trace = poutine.trace(nested_model_guide,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1)
    model_trace = poutine.trace(poutine.replay(nested_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()

    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
                                                                     non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])
    expected_c1 += model_trace.nodes['obs1']['log_prob']

    expected_b1 = (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum()
    expected_b1 += model_trace.nodes['obs1']['log_prob'].sum()

    expected_c0 = (model_trace.nodes['c0']['log_prob'] - guide_trace.nodes['c0']['log_prob'])
    expected_c0 += model_trace.nodes['obs0']['log_prob']

    expected_b0 = (model_trace.nodes['b0']['log_prob'] - guide_trace.nodes['b0']['log_prob'])
    expected_b0 += (model_trace.nodes['c0']['log_prob'] - guide_trace.nodes['c0']['log_prob']).sum()
    expected_b0 += model_trace.nodes['obs0']['log_prob'].sum()

    assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)
    assert_equal(expected_c0, dc['c0'], prec=1.0e-6)
    assert_equal(expected_b0, dc['b0'], prec=1.0e-6)

    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
Example #47
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a trace generator
        """
        # enable parallel enumeration
        guide = poutine.enum(guide, first_available_dim=self.max_iarange_nesting)

        for i in range(self.num_particles):
            for guide_trace in iter_discrete_traces("flat", guide, *args, **kwargs):
                model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                            graph_type="flat").get_trace(*args, **kwargs)

                if is_validation_enabled():
                    check_model_guide_match(model_trace, guide_trace, self.max_iarange_nesting)
                guide_trace = prune_subsample_sites(guide_trace)
                model_trace = prune_subsample_sites(model_trace)
                if is_validation_enabled():
                    check_traceenum_requirements(model_trace, guide_trace)

                model_trace.compute_log_prob()
                guide_trace.compute_score_parts()
                if is_validation_enabled():
                    for site in model_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                    any_enumerated = False
                    for site in guide_trace.nodes.values():
                        if site["type"] == "sample":
                            check_site_shape(site, self.max_iarange_nesting)
                            if site["infer"].get("enumerate"):
                                any_enumerated = True
                    if self.strict_enumeration_warning and not any_enumerated:
                        warnings.warn('TraceEnum_ELBO found no sample sites configured for enumeration. '
                                      'If you want to enumerate sites, you need to @config_enumerate or set '
                                      'infer={"enumerate": "sequential"} or infer={"enumerate": "parallel"}? '
                                      'If you do not want to enumerate, consider using Trace_ELBO instead.')

                yield model_trace, guide_trace
Example #48
    def _get_traces(self, model, guide, *args, **kwargs):
        """
        runs the guide and runs the model against the guide with
        the result packaged as a tracegraph generator

        XXX support for automatically settings args/kwargs to volatile?
        """

        if self.enum_discrete:
            raise NotImplementedError("https://github.com/uber/pyro/issues/220")

        for i in range(self.num_particles):
            guide_trace = poutine.trace(guide,
                                        graph_type="dense").get_trace(*args, **kwargs)
            model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
                                        graph_type="dense").get_trace(*args, **kwargs)

            check_model_guide_match(model_trace, guide_trace)
            guide_trace = prune_subsample_sites(guide_trace)
            model_trace = prune_subsample_sites(model_trace)

            weight = 1.0 / self.num_particles
            yield weight, model_trace, guide_trace
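
A hedged sketch of how the yielded triples might be consumed, assuming pyro's Trace.log_prob_sum helper; the class's real loss method may differ:

# Hypothetical consumer inside the same class: a Monte Carlo ELBO
# estimate averaged over the yielded particles.
def loss(self, model, guide, *args, **kwargs):
    elbo = 0.0
    for weight, model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
        elbo += weight * (model_trace.log_prob_sum() - guide_trace.log_prob_sum())
    return -elbo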
Example #49
def iter_discrete_traces(graph_type, fn, *args, **kwargs):
    """
    Iterate over all discrete choices of a stochastic function.

    When sampling continuous random variables, this behaves like `fn`.
    When sampling discrete random variables, this iterates over all choices.

    This yields `(scale, trace)` pairs, where `scale` is the probability of the
    discrete choices made in the `trace`.

    :param str graph_type: The type of the graph, e.g. "flat" or "dense".
    :param callable fn: A stochastic function.
    :returns: An iterator over (scale, trace) pairs.
    """
    queue = LifoQueue()
    queue.put(Trace())
    while not queue.empty():
        partial_trace = queue.get()
        escape_fn = functools.partial(util.discrete_escape, partial_trace)
        traced_fn = poutine.trace(poutine.escape(poutine.replay(fn, trace=partial_trace), escape_fn),
                                  graph_type=graph_type)
        try:
            full_trace = traced_fn.get_trace(*args, **kwargs)
        except util.NonlocalExit as e:
            for extended_trace in util.enum_extend(traced_fn.trace.copy(), e.site):
                queue.put(extended_trace)
            continue

        # Scale trace by probability of discrete choices.
        log_pdf = full_trace.batch_log_pdf(site_filter=site_is_discrete)
        if isinstance(log_pdf, float):
            log_pdf = torch.Tensor([log_pdf])
        if isinstance(log_pdf, torch.Tensor):
            log_pdf = Variable(log_pdf)
        scale = torch.exp(log_pdf.detach())
        yield scale, full_trace
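
Since each trace is paired with the probability of its discrete branch, an exact expectation over the discrete choices is just a weighted sum. A usage sketch (my_model and f are stand-ins for a user's stochastic function and statistic, not names from this codebase):

# Sum over every discrete branch of my_model, weighting the statistic f
# of each completed trace by that branch's probability.
expectation = 0.0
for scale, trace in iter_discrete_traces("flat", my_model):
    expectation = expectation + scale * f(trace)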
Example #50
def main(**kwargs):

    args = argparse.Namespace(**kwargs)

    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing (a sketch of the decay helpers
    # appears after this function).
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior.
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name):
        lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model, air.guide,
              optim.Adam(per_param_optim_args),
              loss=TraceGraph_ELBO())

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X, args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.viz and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if 'save' in args and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)
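
The annealing wrapper in main dispatches to lin_decay/exp_decay helpers defined elsewhere in the AIR example. A plausible sketch of the linear one, assuming the (initial, final, begin, duration, step) signature suggested by the call site; the example's actual helper may differ:

def lin_decay(initial, final, begin, duration, step):
    # Linearly interpolate from `initial` to `final` over `duration`
    # optimization steps, starting at step `begin`; clamp outside that window.
    frac = min(max((step - begin) / float(duration), 0.0), 1.0)
    return initial + frac * (final - initial)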
Example #51
def test_compute_downstream_costs_big_model_guide_pair(include_inner_1, include_single, flip_c23,
                                                       include_triple, include_z1):
    guide_trace = poutine.trace(big_model_guide,
                                graph_type="dense").get_trace(include_obs=False, include_inner_1=include_inner_1,
                                                              include_single=include_single, flip_c23=flip_c23,
                                                              include_triple=include_triple, include_z1=include_z1)
    model_trace = poutine.trace(poutine.replay(big_model_guide, trace=guide_trace),
                                graph_type="dense").get_trace(include_obs=True, include_inner_1=include_inner_1,
                                                              include_single=include_single, flip_c23=flip_c23,
                                                              include_triple=include_triple, include_z1=include_z1)

    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace,
                                             non_reparam_nodes)

    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(model_trace, guide_trace,
                                                                     non_reparam_nodes)

    assert dc_nodes == dc_nodes_brute

    expected_nodes_full_model = {'a1': {'c2', 'a1', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'b0'}, 'd2': {'obs', 'd2'},
                                 'd1': {'obs', 'd1', 'd2'}, 'c3': {'d2', 'obs', 'd1', 'c3'},
                                 'b0': {'b0', 'd1', 'c1', 'obs', 'b1', 'd2', 'c3', 'c2'},
                                 'b1': {'obs', 'b1', 'd1', 'd2', 'c3', 'c1', 'c2'},
                                 'c1': {'d1', 'c1', 'obs', 'd2', 'c3', 'c2'},
                                 'c2': {'obs', 'd1', 'c3', 'd2', 'c2'}}
    if not include_triple and include_inner_1 and include_single and not flip_c23:
        assert(dc_nodes == expected_nodes_full_model)

    expected_b1 = (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob'])
    expected_b1 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
    expected_b1 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
    expected_b1 += model_trace.nodes['obs']['log_prob'].sum(0, keepdim=False)
    if include_inner_1:
        expected_b1 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum(0)
        expected_b1 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum(0)
    assert_equal(expected_b1, dc['b1'], prec=1.0e-6)

    if include_single:
        expected_b0 = (model_trace.nodes['b0']['log_prob'] - guide_trace.nodes['b0']['log_prob'])
        expected_b0 += (model_trace.nodes['b1']['log_prob'] - guide_trace.nodes['b1']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum()
        expected_b0 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum()
        expected_b0 += model_trace.nodes['obs']['log_prob'].sum()
        if include_inner_1:
            expected_b0 += (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']).sum()
            expected_b0 += (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']).sum()
        assert_equal(expected_b0, dc['b0'], prec=1.0e-6)
        assert dc['b0'].size() == (5,)

    if include_inner_1:
        expected_c3 = (model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob'])
        expected_c3 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c3 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c3 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c2 = (model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob'])
        expected_c2 += (model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']).sum(0)
        expected_c2 += (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']).sum(0)
        expected_c2 += model_trace.nodes['obs']['log_prob'].sum(0)

        expected_c1 = (model_trace.nodes['c1']['log_prob'] - guide_trace.nodes['c1']['log_prob'])

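        # When flip_c23=True the guide samples c3 before c2, so c2 is
        # downstream of c3 in the guide; the model always orders c2 before
        # c3, so only the model-side term log p(c3) lands in c2's
        # downstream cost in that case.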
        if flip_c23:
            expected_c3 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
            expected_c2 += model_trace.nodes['c3']['log_prob']
        else:
            expected_c2 += model_trace.nodes['c3']['log_prob'] - guide_trace.nodes['c3']['log_prob']
            expected_c1 += model_trace.nodes['c2']['log_prob'] - guide_trace.nodes['c2']['log_prob']
        expected_c1 += expected_c3

        assert_equal(expected_c1, dc['c1'], prec=1.0e-6)
        assert_equal(expected_c2, dc['c2'], prec=1.0e-6)
        assert_equal(expected_c3, dc['c3'], prec=1.0e-6)

    expected_d1 = model_trace.nodes['d1']['log_prob'] - guide_trace.nodes['d1']['log_prob']
    expected_d1 += model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob']
    expected_d1 += model_trace.nodes['obs']['log_prob']

    expected_d2 = (model_trace.nodes['d2']['log_prob'] - guide_trace.nodes['d2']['log_prob'])
    expected_d2 += model_trace.nodes['obs']['log_prob']

    if include_triple:
        expected_z0 = dc['a1'] + model_trace.nodes['z0']['log_prob'] - guide_trace.nodes['z0']['log_prob']
        assert_equal(expected_z0, dc['z0'], prec=1.0e-6)
    assert_equal(expected_d2, dc['d2'], prec=1.0e-6)
    assert_equal(expected_d1, dc['d1'], prec=1.0e-6)

    assert dc['b1'].size() == (2,)
    assert dc['d2'].size() == (4, 2)

    for k in dc:
        assert(guide_trace.nodes[k]['log_prob'].size() == dc[k].size())
        assert_equal(dc[k], dc_brute[k])
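
For reference, the invariant these assertions check can be stated compactly: a site's downstream cost is its own log p - log q plus the (plate-reduced) contribution of every site below it, with model-only sites such as obs contributing log p alone. A conceptual sketch, not pyro's actual _brute_force_compute_downstream_costs; the `downstream` map and the elided .sum() reductions are assumptions:

def downstream_cost(node, model_trace, guide_trace, downstream):
    # `downstream` maps each guide site to the sites below it; plate
    # reductions (.sum() over dims not shared with `node`) are elided.
    cost = model_trace.nodes[node]['log_prob'] - guide_trace.nodes[node]['log_prob']
    for site in downstream[node]:
        term = model_trace.nodes[site]['log_prob']
        if site in guide_trace.nodes:  # observed sites have no guide term
            term = term - guide_trace.nodes[site]['log_prob']
        cost = cost + term
    return cost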
Example #52
svi = SVI(air.model, air.guide,
          optim.Adam(per_param_optim_args),
          loss='ELBO',
          trace_graph=True)

for i in range(1, args.num_steps + 1):

    loss = svi.step(X, args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

    if args.progress_every > 0 and i % args.progress_every == 0:
        print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
            i,
            (i * args.batch_size) / X_size,
            (time.time() - t0) / 3600,
            loss / X_size))

    if args.viz and i % args.viz_every == 0:
        trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
        z, recons = poutine.replay(air.prior, trace)(examples_to_viz.size(0))
        z_wheres = post_process_latents(z)

        # Show data with inferred object positions.
        vis.images(draw_many(examples_to_viz, z_wheres))
        # Show reconstructions of data.
        vis.images(draw_many(recons, z_wheres))

    if 'save' in args and i % args.save_every == 0:
        print('Saving parameters...')
        torch.save(air.state_dict(), args.save)
Example #53
    def test_replay_full(self):
        guide_trace = poutine.trace(self.guide).get_trace()
        model_trace = poutine.trace(poutine.replay(self.model, trace=guide_trace)).get_trace()
        for name in self.full_sample_sites.keys():
            assert_equal(model_trace.nodes[name]["value"],
                         guide_trace.nodes[name]["value"])