Example no. 1
0
def random_module(name, nn_module, prior, *args, **kwargs):
    """
    Places a prior over the parameters of the module `nn_module`.
    Returns a distribution (callable) over `nn.Module`s; each call
    yields a freshly sampled `nn.Module`.

    See the `Bayesian Regression tutorial <http://pyro.ai/examples/bayesian_regression.html>`_
    for an example.

    :param name: name of pyro module
    :type name: str
    :param nn_module: the module to be registered with pyro
    :type nn_module: torch.nn.Module
    :param prior: pyro distribution, stochastic function, or python dict with parameter names
                  as keys and respective distributions/stochastic functions as values.
    :returns: a callable which returns a sampled module
    """
    assert hasattr(nn_module, "parameters"), "Module is not a NN module."
    # lift the `module` primitive so its params are drawn from the prior
    lifted_module = poutine.lift(module, prior)

    def sample_module():
        # deep-copy so the caller's module parameters are never mutated
        fresh_copy = copy.deepcopy(nn_module)
        # update_module_params must be True or the lifted module will not update local params
        return lifted_module(name, fresh_copy, update_module_params=True,
                             *args, **kwargs)

    return sample_module
Example no. 2
0
def random_module(name, nn_module, prior, *args, **kwargs):
    """
    Places a prior over the parameters of `nn_module` and returns a
    callable; every invocation of that callable produces a sampled
    `nn.Module`.

    See the `Bayesian Regression tutorial <http://pyro.ai/examples/bayesian_regression.html>`_
    for an example.

    :param name: name of pyro module
    :type name: str
    :param nn_module: the module to be registered with pyro
    :type nn_module: torch.nn.Module
    :param prior: pyro distribution, stochastic function, or python dict with parameter names
                  as keys and respective distributions/stochastic functions as values.
    :returns: a callable which returns a sampled module
    """
    assert hasattr(nn_module, "parameters"), "Module is not a NN module."
    # register params in the param store by lifting the module primitive
    lifted = poutine.lift(module, prior)

    def draw_module():
        # work on a deep copy so repeated draws do not share parameter state
        module_clone = copy.deepcopy(nn_module)
        # update_module_params must be True or the lifted module will not update local params
        return lifted(name, module_clone, update_module_params=True,
                      *args, **kwargs)

    return draw_module
Example no. 3
0
def random_module(name, nn_module, prior, *args, **kwargs):
    """
    Places a prior over the parameters of the module `nn_module`.

    See the `Bayesian Regression <http://pyro.ai/examples/bayesian_regression.html>`_
    tutorial for an example.

    :param name: name of pyro module
    :type name: str
    :param nn_module: the module to be registered with pyro
    :type nn_module: torch.nn.Module
    :param prior: prior distribution or iterable over distributions
    :returns: a callable which returns a sampled module
    """
    assert hasattr(nn_module, "parameters"), "Module is not a NN module."
    # lift the module primitive so its params are registered/sampled via the prior
    lifted_primitive = poutine.lift(module, prior)

    def sampler():
        # clone the module so the original parameters remain untouched
        cloned = copy.deepcopy(nn_module)
        # update_module_params must be True or the lifted module will not update local params
        return lifted_primitive(name, cloned, update_module_params=True,
                                *args, **kwargs)

    return sampler
Example no. 4
0
 def test_splice(self):
     # Param sites named below should no longer appear in the lifted trace;
     # every other site must survive lifting unchanged.
     lifted_names = ('loc1', 'loc2', 'scale1', 'scale2')
     base_trace = poutine.trace(self.guide).get_trace()
     lifted_trace = poutine.trace(
         poutine.lift(self.guide, prior=self.prior)).get_trace()
     for site in base_trace.nodes.keys():
         present = site in lifted_trace
         if site in lifted_names:
             assert not present
         else:
             assert present
Example no. 5
0
 def test_splice(self):
     # sites that lifting should remove from the guide trace
     removed = ('mu1', 'mu2', 'sigma1', 'sigma2')
     trace = poutine.trace(self.guide).get_trace()
     lifted = poutine.trace(
         poutine.lift(self.guide, prior=self.prior)).get_trace()
     for site_name in trace.nodes.keys():
         if site_name in removed:
             self.assertFalse(site_name in lifted)
         else:
             self.assertTrue(site_name in lifted)
Example no. 6
0
 def test_splice(self):
     guide_tr = poutine.trace(self.guide).get_trace()
     lifted_guide = poutine.lift(self.guide, prior=self.prior)
     lifted_tr = poutine.trace(lifted_guide).get_trace()
     removed_sites = ('loc1', 'loc2', 'scale1', 'scale2')
     for site in guide_tr.nodes.keys():
         # the named param sites should be gone; all others should remain
         if site in removed_sites:
             assert site not in lifted_tr
         else:
             assert site in lifted_tr
Example no. 7
0
 def test_splice(self):
     full_trace = poutine.trace(self.guide).get_trace()
     lifted_trace = poutine.trace(
         poutine.lift(self.guide, prior=self.prior)
     ).get_trace()
     # mu*/sigma* sites should disappear after lifting; the rest stay
     for node_name in full_trace.nodes.keys():
         if node_name in ('mu1', 'mu2', 'sigma1', 'sigma2'):
             self.assertFalse(node_name in lifted_trace)
         else:
             self.assertTrue(node_name in lifted_trace)
Example no. 8
0
 def test_prior_dict(self):
     # With a dict-valued prior, each listed param is replaced by a sample
     # site whose fn is named "<param>_prior".
     base = poutine.trace(self.guide).get_trace()
     lifted = poutine.trace(
         poutine.lift(self.guide, prior=self.prior_dict)).get_trace()
     named_priors = {'scale1', 'loc1', 'scale2', 'loc2'}
     for site in base.nodes.keys():
         assert site in lifted
         if site in named_priors:
             assert lifted.nodes[site]['fn'].__name__ == site + "_prior"
         if base.nodes[site]["type"] == "param":
             # params become unobserved sample sites after lifting
             assert lifted.nodes[site]["type"] == "sample"
             assert not lifted.nodes[site]["is_observed"]
Example no. 9
0
 def test_prior_dict(self):
     guide_trace = poutine.trace(self.guide).get_trace()
     lifted_fn = poutine.lift(self.guide, prior=self.prior_dict)
     lifted_trace = poutine.trace(lifted_fn).get_trace()
     for node_name in guide_trace.nodes.keys():
         assert node_name in lifted_trace
         lifted_node = lifted_trace.nodes[node_name]
         if node_name in {'scale1', 'loc1', 'scale2', 'loc2'}:
             # per-param priors are wrapped in a fn named "<param>_prior"
             assert node_name + "_prior" == lifted_node['fn'].__name__
         if guide_trace.nodes[node_name]["type"] == "param":
             assert lifted_node["type"] == "sample"
             assert not lifted_node["is_observed"]
Example no. 10
0
 def test_prior_dict(self):
     trace = poutine.trace(self.guide).get_trace()
     lifted = poutine.trace(
         poutine.lift(self.guide, prior=self.prior_dict)).get_trace()
     prior_sites = {'sigma1', 'mu1', 'sigma2', 'mu2'}
     for site in trace.nodes.keys():
         self.assertTrue(site in lifted)
         if site in prior_sites:
             # each lifted param's fn carries the "<param>_prior" name
             self.assertTrue(lifted.nodes[site]['fn'].__name__ == site + "_prior")
         if trace.nodes[site]["type"] == "param":
             self.assertTrue(lifted.nodes[site]["type"] == "sample" and
                             not lifted.nodes[site]["is_observed"])
Example no. 11
0
    def _traces(self, *args, **kwargs):
        """
        Generator yielding ``(model_trace, weight)`` pairs from an
        approximate posterior, each with weight 1.0.

        Pipeline: (1) pick the highest-log-prob of several random model
        traces as an initialization; (2) lift every param site to a sample
        site under a weakly informative prior; (3) fit a Laplace
        approximation around a point found by LBFGS; (4) replay the lifted
        model against guide samples and yield the resulting traces.

        NOTE(review): clears the global param store before returning.
        """
        # find good initial trace
        model_trace = poutine.trace(self.model).get_trace(*args, **kwargs)
        best_log_prob = model_trace.log_prob_sum()
        for i in range(20):
            trace = poutine.trace(self.model).get_trace(*args, **kwargs)
            log_prob = trace.log_prob_sum()
            if log_prob > best_log_prob:
                best_log_prob = log_prob
                model_trace = trace

        # lift model
        prior, unpacked = {}, {}
        param_constraints = pyro.get_param_store().get_state()["constraints"]
        for name, node in model_trace.nodes.items():
            if node["type"] == "param":
                # positive-constrained params get a heavy-tailed positive prior
                if param_constraints[name] is constraints.positive:
                    prior[name] = dist.HalfCauchy(2)
                else:
                    prior[name] = dist.Normal(0, 10)
                unpacked[name] = pyro.param(name).unconstrained()
            elif name in self.start:
                # user-supplied starting values take precedence over trace values
                unpacked[name] = self.start[name]
            elif node["type"] == "sample" and not node["is_observed"]:
                # map latent sample values into unconstrained space
                unpacked[name] = transform_to(node["fn"].support).inv(
                    node["value"])
        lifted_model = poutine.lift(self.model, prior)

        # define guide: flatten all unconstrained values into one vector param
        packed = torch.cat(
            [v.clone().detach().reshape(-1) for v in unpacked.values()])
        pyro.param("auto_loc", packed)
        delta_guide = AutoLaplaceApproximation(lifted_model)

        # train guide
        optimizer = torch.optim.LBFGS(
            (pyro.param("auto_loc").unconstrained(), ), lr=0.1, max_iter=500)
        loss_and_grads = Trace_ELBO().loss_and_grads

        def closure():
            # LBFGS re-evaluates the loss multiple times per step
            optimizer.zero_grad()
            return loss_and_grads(lifted_model, delta_guide, *args, **kwargs)

        optimizer.step(closure)
        guide = delta_guide.laplace_approximation(*args, **kwargs)

        # get posterior
        for i in range(self.num_samples):
            guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
            model_poutine = poutine.trace(
                poutine.replay(lifted_model, trace=guide_trace))
            yield model_poutine.get_trace(*args, **kwargs), 1.0

        pyro.clear_param_store()
Example no. 12
0
 def test_unlifted_param(self):
     # partial_dict lifts only sigma1/mu1; sigma2/mu2 must stay as params
     base_trace = poutine.trace(self.guide).get_trace()
     lifted_trace = poutine.trace(
         poutine.lift(self.guide, prior=self.partial_dict)).get_trace()
     for site in base_trace.nodes.keys():
         assert site in lifted_trace
         node = lifted_trace.nodes[site]
         if site in ('sigma1', 'mu1'):
             assert node['fn'].__name__ == site + "_prior"
             assert node["type"] == "sample"
             assert not node["is_observed"]
         if site in ('sigma2', 'mu2'):
             assert node["type"] == "param"
Example no. 13
0
 def test_prior_dict(self):
     guide_trace = poutine.trace(self.guide).get_trace()
     lifted_guide = poutine.lift(self.guide, prior=self.prior_dict)
     lifted_trace = poutine.trace(lifted_guide).get_trace()
     for site in guide_trace.nodes.keys():
         self.assertTrue(site in lifted_trace)
         if site in {'sigma1', 'mu1', 'sigma2', 'mu2'}:
             # lifted params carry a prior fn named "<param>_prior"
             self.assertTrue(
                 site + "_prior" == lifted_trace.nodes[site]['fn'].__name__)
         if guide_trace.nodes[site]["type"] == "param":
             self.assertTrue(lifted_trace.nodes[site]["type"] == "sample"
                             and not lifted_trace.nodes[site]["is_observed"])
Example no. 14
0
 def test_unlifted_param(self):
     # Only scale1/loc1 are in partial_dict, so only they get lifted.
     trace = poutine.trace(self.guide).get_trace()
     lifted = poutine.lift(self.guide, prior=self.partial_dict)
     lifted_trace = poutine.trace(lifted).get_trace()
     for site in trace.nodes.keys():
         assert site in lifted_trace
         site_node = lifted_trace.nodes[site]
         if site in ("scale1", "loc1"):
             assert site_node["fn"].__name__ == site + "_prior"
             assert site_node["type"] == "sample"
             assert not site_node["is_observed"]
         if site in ("scale2", "loc2"):
             assert site_node["type"] == "param"
Example no. 15
0
def sim(posterior, data=None, n=1000):
    """
    Draw `n` simulated observations from `posterior`.

    Without `data`, resamples among the posterior's execution traces and
    draws from each chosen trace's observation distribution. With `data`,
    masks the observation site and runs a TracePredictive.

    NOTE(review): when `data` is given, its observation entry is set to
    None in place — callers see that mutation.
    """
    obs_site = posterior.exec_traces[0].observation_nodes[-1]
    if data is None:
        draws = []
        for _ in range(n):
            # sample a trace index from the posterior's categorical weights
            chosen = posterior.exec_traces[posterior._categorical.sample().item()]
            draws.append(chosen.nodes[obs_site]["fn"].sample())
    else:
        # mask the observed value so the predictive resamples it
        data[obs_site] = None
        lifted_model = poutine.lift(posterior.model, lambda: None)
        predictive = TracePredictive(lifted_model, posterior, n).run(**data)
        draws = [t.nodes[obs_site]["value"] for t in predictive.exec_traces]
    return torch.stack(draws).detach()
Example no. 16
0
def sim(posterior, data=None, n=1000):
    """
    Return `n` simulated observations from `posterior` as a stacked,
    detached tensor.

    Without `data`, resamples among the posterior's execution traces and
    draws from the observation distribution of each chosen trace. With
    `data`, fills every model argument (missing ones as None) and runs a
    TracePredictive over the lifted model.
    """
    obs_site = posterior.exec_traces[0].observation_nodes[-1]
    if data is None:
        draws = []
        for _ in range(n):
            idx = posterior._categorical.sample().item()
            node = posterior.exec_traces[idx].nodes[obs_site]
            draws.append(node["fn"].sample())
    else:
        # supply every model parameter by name, masking absent ones with None
        model_kwargs = {name: data[name] if name in data else None
                        for name in inspect.signature(posterior.model).parameters}
        lifted = poutine.lift(posterior.model, dist.Normal(0, 1))
        predictive = TracePredictive(lifted, posterior, n).run(**model_kwargs)
        draws = [tr.nodes[obs_site]["value"] for tr in predictive.exec_traces]
    return torch.stack(draws).detach()
Example no. 17
0
def random_module(name, nn_module, prior, *args, **kwargs):
    r"""
    .. warning::
        The `random_module` primitive is deprecated, and will be removed
        in a future release. Use :class:`~pyro.nn.module.PyroModule` instead
        to create Bayesian modules from :class:`torch.nn.Module` instances.
        See the `Bayesian Regression tutorial <http://pyro.ai/examples/bayesian_regression.html>`_
        for an example.

    DEPRECATED Places a prior over the parameters of the module `nn_module`.
    Returns a distribution (callable) over `nn.Module`\s; each call yields
    a sampled `nn.Module`.

    :param name: name of pyro module
    :type name: str
    :param nn_module: the module to be registered with pyro
    :type nn_module: torch.nn.Module
    :param prior: pyro distribution, stochastic function, or python dict with parameter names
                  as keys and respective distributions/stochastic functions as values.
    :returns: a callable which returns a sampled module
    """
    # emit the deprecation notice on every call
    warnings.warn(
        "The `random_module` primitive is deprecated, and will be removed "
        "in a future release. Use `pyro.nn.Module` to create Bayesian "
        "modules from `torch.nn.Module` instances.",
        FutureWarning,
    )

    assert hasattr(nn_module, "parameters"), "Module is not a NN module."
    # lift the module primitive so its params are sampled from the prior
    lifted_fn = poutine.lift(module, prior=prior)

    def sampled_module():
        # deep-copy so repeated draws never share parameter tensors
        module_copy = copy.deepcopy(nn_module)
        # update_module_params must be True or the lifted module will not update local params
        return lifted_fn(name, module_copy, update_module_params=True,
                         *args, **kwargs)

    return sampled_module
Example no. 18
0
 def test_memoize(self):
     # tracing a lifted guide with duplicated param sites should not raise
     lifted_guide = poutine.lift(self.dup_param_guide, prior=dist.Normal(0, 1))
     poutine.trace(lifted_guide)()