Exemple #1
0
 def __init__(self, size):
     """Create the module's three Pyro parameters and two sample sites.

     :param size: Shape argument forwarded to ``torch.zeros``,
         ``torch.randn`` and ``torch.ones`` when building initial values.
     """
     super().__init__()

     def _init_y():
         # Initial value for ``y``, supplied as a zero-argument callable
         # instead of a ready-made tensor.
         return torch.randn(size)

     def _prior_t(module):
         # Prior for ``t``; it reads the ``s`` and ``z`` attributes of the
         # object it is called with.
         return dist.Normal(module.s, module.z)

     # Parameters: tensor-initialised, callable-initialised, and one with
     # a positivity constraint and an explicit event dimension.
     self.x = PyroParam(torch.zeros(size))
     self.y = PyroParam(_init_y)
     z_init = torch.ones(size)
     self.z = PyroParam(z_init, constraint=constraints.positive, event_dim=1)

     # Sample sites: a fixed standard-normal prior, then a prior computed
     # from other attributes of the module.
     self.s = PyroSample(dist.Normal(0, 1))
     self.t = PyroSample(_prior_t)
Exemple #2
0
 def __init__(self):
     """Populate the module with one attribute of each supported kind.

     Attributes ``a`` through ``g`` are, in order: a ``torch.nn.Parameter``,
     a registered buffer, a bare tensor, a distribution object, a
     constrained ``PyroParam``, a ``PyroSample`` with a fixed prior, and a
     ``PyroSample`` whose prior is computed from ``f``.
     """
     super().__init__()

     def _prior_g(module):
         # Prior for ``g``; it reads the ``f`` attribute of the object it is
         # called with.
         return dist.Normal(module.f, 1)

     # Plain torch state.
     self.a = torch.nn.Parameter(torch.zeros(2))
     self.register_buffer("b", torch.zeros(3))

     # Bare tensor attribute: this wouldn't work with torch.nn.Module.to()
     self.c = torch.randn(4)
     self.d = dist.Normal(0, 1)

     # Scalar parameter constrained to lie above 0.5.
     e_init = torch.randn(())
     lower_bound = torch.tensor(0.5)
     self.e = PyroParam(e_init, constraint=constraints.greater_than(lower_bound))

     # Sample sites.
     self.f = PyroSample(dist.Normal(0, 1))
     self.g = PyroSample(_prior_g)
Exemple #3
0
 def update_(self, net):
     """Overwrite the exposed PyroSample sites of ``net`` with new priors.

     Every module yielded by ``net.named_modules()`` is visited. For each
     sample site found directly on that module, ``self.expose_fn`` decides
     whether the site is exposed. Exposed sites are replaced in place by
     ``PyroSample(self.prior_dist(full_name, module, site))``; all other
     sites are left untouched.

     The name passed to ``expose_fn`` and ``prior_dist`` is the module name
     and the site name joined by a dot.

     :param net: PyroModule whose sample sites are updated in place.
     """
     for module_name, module in net.named_modules():
         # Collect the sites into a list before looping, since the loop body
         # reassigns attributes on ``module``.
         direct_sites = list(util.named_pyro_samples(module, recurse=False))
         for site_name, site in direct_sites:
             full_name = ".".join((module_name, site_name))
             if not self.expose_fn(module, full_name):
                 continue
             replacement = PyroSample(self.prior_dist(full_name, module, site))
             setattr(module, site_name, replacement)
Exemple #4
0
 def apply_(self, net):
     """Replace every nn.Parameter of ``net`` by a PyroSample or a PyroParam.

     Every module yielded by ``net.named_modules()`` is visited. For each
     parameter registered directly on that module, ``self.expose_fn``
     decides how it is replaced:

     * exposed parameters become ``PyroSample`` sites whose prior is
       ``self.prior_dist(full_name, module, param)``, expanded to the
       parameter's shape with all of its dimensions declared as event
       dimensions;
     * all other parameters become ``PyroParam`` attributes initialised
       from the parameter's current data.

     The name passed to ``expose_fn`` and ``prior_dist`` is the module name
     and the parameter name joined by a dot.

     :param net: PyroModule whose parameters are replaced in place.
     """
     for module_name, module in net.named_modules():
         # Collect the parameters into a list before looping, since the loop
         # body reassigns attributes on ``module``.
         direct_params = list(module.named_parameters(recurse=False))
         for param_name, param in direct_params:
             full_name = ".".join((module_name, param_name))

             if not self.expose_fn(module, full_name):
                 # Hidden parameter: keep it deterministic.
                 setattr(module, param_name, PyroParam(param.data.detach()))
                 continue

             base_prior = self.prior_dist(full_name, module, param)
             shaped_prior = base_prior.expand(param.shape).to_event(param.dim())
             setattr(module, param_name, PyroSample(shaped_prior))
Exemple #5
0
    def set_prior(self, name, prior):
        """
        Deprecated way of attaching a prior to a parameter.

        Emits a ``UserWarning`` that points to the replacement idiom, then
        assigns ``PyroSample(prior)`` to the attribute called ``name``.

        :param str name: Name of the parameter.
        :param ~pyro.distributions.distribution.Distribution prior: A Pyro prior
            distribution.
        """
        template = ("The method `self.set_prior({}, prior)` has been deprecated"
                    " in favor of `self.{} = PyroSample(prior)`.")
        # Both placeholders are filled with the same parameter name.
        message = template.format(name, name)
        warnings.warn(message, UserWarning)
        setattr(self, name, PyroSample(prior))