def _setup_prototype(self, *args, **kwargs):
    """Create per-site location and scale parameters from the prototype trace.

    After the parent class has set up the prototype, every stochastic node
    gets an entry under ``self.locs`` and ``self.scales``. The location is
    initialized from the site's value mapped to unconstrained space, the
    scale is filled with ``self._init_scale``, and the unconstrained event
    dimension of each site is recorded in ``self._event_dims``.

    :param args: positional arguments forwarded to the parent implementation.
    :param kwargs: keyword arguments forwarded to the parent implementation.
    """
    super()._setup_prototype(*args, **kwargs)

    self._event_dims = {}
    self.locs = PyroModule()
    self.scales = PyroModule()

    for name, node in self.prototype_trace.iter_stochastic_nodes():
        site_value = node["value"]

        # The unconstrained event_dim may differ from the constrained one,
        # so it is derived from the rank change caused by the bijection.
        with helpful_support_errors(node):
            to_unconstrained = biject_to(node["fn"].support).inv
            loc = to_unconstrained(site_value.detach()).detach()
        event_dim = node["fn"].event_dim + loc.dim() - site_value.dim()
        self._event_dims[name] = event_dim

        # Under subsampling the initial value only covers a subsample;
        # tile it periodically up to the full plate size.
        for frame in node["cond_indep_stack"]:
            full_size = getattr(frame, "full_size", frame.size)
            if full_size == frame.size:
                continue
            loc = periodic_repeat(
                loc, full_size, frame.dim - event_dim
            ).contiguous()

        scale = torch.full_like(loc, self._init_scale)

        loc_param = PyroParam(loc, constraints.real, event_dim)
        deep_setattr(self.locs, name, loc_param)
        scale_param = PyroParam(scale, self.scale_constraint, event_dim)
        deep_setattr(self.scales, name, scale_param)
def __init__(self):
    """Set up a scalar parameter, a constant-filled linear layer and a nested module.

    Attributes created, in registration order:

    * ``x``   -- scalar ``nn.Parameter`` initialized to 0.
    * ``m``   -- ``Linear(2, 3)`` whose weights are all 1 and biases all 2.
    * ``p``   -- an empty ``PyroModule`` holding ``p.x``, a scalar
      ``nn.Parameter`` initialized to 3.
    """
    super().__init__()

    self.x = nn.Parameter(torch.tensor(0.0))

    # Fill the layer with known constants before attaching it.
    linear = torch.nn.Linear(2, 3)
    linear.weight.data.fill_(1.0)
    linear.bias.data.fill_(2.0)
    self.m = linear

    self.p = PyroModule()
    self.p.x = nn.Parameter(torch.tensor(3.0))
def __init__(self):
    """Set up plain and constrained parameters on this module and two children.

    Attributes created, in registration order:

    * ``x``   -- ``nn.Parameter`` initialized to 0.
    * ``y``   -- ``PyroParam`` initialized to 1 with a positive constraint.
    * ``m``   -- plain ``nn.Module`` holding ``m.u`` (``nn.Parameter``, 2).
    * ``p``   -- ``PyroModule`` holding ``p.v`` (``nn.Parameter``, 3) and
      ``p.w`` (``PyroParam``, 4, positive constraint).
    """
    super().__init__()
    positive = constraints.positive

    # Parameters owned directly by this module.
    self.x = nn.Parameter(torch.tensor(0.0))
    self.y = PyroParam(torch.tensor(1.0), constraint=positive)

    # A plain torch child module.
    self.m = nn.Module()
    self.m.u = nn.Parameter(torch.tensor(2.0))

    # A Pyro child module with one plain and one constrained parameter.
    self.p = PyroModule()
    self.p.v = nn.Parameter(torch.tensor(3.0))
    self.p.w = PyroParam(torch.tensor(4.0), constraint=positive)
def test_torch_serialize():
    """Round-trip a PyroModule through torch.save / torch.load.

    The restored module must expose the same constrained value for ``x`` and
    the same set of parameter names as the module that was saved.
    """
    module = PyroModule()
    module.x = PyroParam(torch.tensor(1.234), constraints.positive)
    module.y = nn.Parameter(torch.randn(3))
    assert isinstance(module.x, torch.Tensor)

    buffer = io.BytesIO()
    # Work around https://github.com/pytorch/pytorch/issues/27972
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        torch.save(module, buffer)
        # Clear the global store before reloading the module.
        pyro.clear_param_store()
        buffer.seek(0)
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may refuse a fully pickled module --
        # confirm against the torch versions this test must support.
        restored = torch.load(buffer)

    assert_equal(restored.x, module.x)
    restored_names = set(dict(restored.named_parameters()))
    original_names = set(dict(module.named_parameters()))
    assert restored_names == original_names
def test_constraints(shape, constraint_):
    """Exercise the lifecycle of a constrained ``PyroParam`` attribute.

    Covers initial assignment, reassignment from a plain tensor, in-place
    mutation of the unconstrained parameter, and deletion.

    :param shape: shape of the parameter tensor under test.
    :param constraint_: constraint object attached to the parameter.
    """
    mod = PyroModule()

    # Initial assignment through PyroParam.
    mod.x = PyroParam(torch.full(shape, 1e-4), constraint_)
    assert isinstance(mod.x, torch.Tensor)
    assert isinstance(mod.x_unconstrained, nn.Parameter)
    assert mod.x.shape == shape
    assert constraint_.check(mod.x).all()

    # Reassigning a plain tensor keeps the attribute constrained.
    tiny_positive = torch.randn(shape).exp() * 1e-6
    mod.x = tiny_positive
    assert isinstance(mod.x_unconstrained, nn.Parameter)
    assert isinstance(mod.x, torch.Tensor)
    assert mod.x.shape == shape
    assert constraint_.check(mod.x).all()

    # Mutating the unconstrained storage is reflected in the constrained view.
    assert isinstance(mod.x_unconstrained, torch.Tensor)
    raw = mod.x_unconstrained.data.normal_()
    constrained_now = mod.x.data
    assert_equal(constrained_now, transform_to(constraint_)(raw))
    assert constraint_.check(mod.x).all()

    # Deleting removes both the constrained and unconstrained attributes.
    del mod.x
    assert "x" not in mod._pyro_params
    assert not hasattr(mod, "x")
    assert not hasattr(mod, "x_unconstrained")
def test_cache():
    """Check PyroModule attribute caching within and across ``__call__``.

    Within a single invocation, repeated reads of plain parameters and
    sampled attributes must return the very same object. Across separate
    invocations, constrained parameters and sampled attributes must be
    fresh objects.
    """

    class MyModule(PyroModule):
        def forward(self):
            # Read every attribute twice inside one __call__.
            return [self.gather(), self.gather()]

        def gather(self):
            return {
                "a": self.a,
                "b": self.b,
                "c": self.c,
                "p.d": self.p.d,
                "p.e": self.p.e,
                "p.f": self.p.f,
            }

    module = MyModule()
    module.a = nn.Parameter(torch.tensor(0.))
    module.b = PyroParam(torch.tensor(1.), constraint=constraints.positive)
    module.c = PyroSample(dist.Normal(0, 1))
    module.p = PyroModule()
    module.p.d = nn.Parameter(torch.tensor(3.))
    module.p.e = PyroParam(torch.tensor(4.), constraint=constraints.positive)
    module.p.f = PyroSample(dist.Normal(0, 1))

    assert module._pyro_context is module.p._pyro_context

    # Check that results are cached with an invocation of .__call__().
    result1 = module()
    actual, expected = result1
    for key in ["a", "c", "p.d", "p.f"]:
        assert actual[key] is expected[key], key

    # Check that results are not cached across invocations of .__call__().
    # Compare the per-key values: the enclosing dicts are always distinct
    # objects, so comparing them would pass regardless of caching.
    result2 = module()
    for key in ["b", "c", "p.e", "p.f"]:
        assert result1[0][key] is not result2[0][key], key
def test_submodule_contains_torch_module():
    """Attaching a PyroModule that already holds a torch module must not raise."""
    inner = PyroModule()
    inner.linear = nn.Linear(1, 1)

    outer = PyroModule()
    outer.child = inner
def test_delete():
    """A PyroParam can be deleted and then re-created under the same name."""
    module = PyroModule()
    module.a = PyroParam(torch.tensor(1.0))
    del module.a

    # Re-register under the same name with a different value.
    module.a = PyroParam(torch.tensor(0.1))
    current = module.a.detach()
    assert_equal(current, torch.tensor(0.1))
def _setup_prototype(self, *args, **kwargs):
    """Allocate guide parameters and order the sample sites by dependency.

    After the parent setup and ``self._auto_config`` have run, this method:

    1. records each site's batch shape and unconstrained event shape and
       flattens its unconstrained initial value;
    2. registers location parameters, conditional-distribution parameters
       (or a user callable) and upstream-dependency modules per site;
    3. stores the sites in ``self._sorted_sites`` so that every site comes
       after the sites it depends on;
    4. clears the prototype trace nodes to save memory.

    :param args: positional arguments forwarded to the parent setup and to
        ``self._auto_config``.
    :param kwargs: keyword arguments forwarded to the parent setup and to
        ``self._auto_config``.
    :raises ValueError: if a conditional is neither callable nor one of
        ``"delta"``, ``"normal"``, ``"mvn"``, or if a dependency is neither
        the string ``"linear"`` nor a callable.
    """
    super()._setup_prototype(*args, **kwargs)

    self.locs = PyroModule()
    self.scales = PyroModule()
    self.scale_trils = PyroModule()
    self.conds = PyroModule()
    self.deps = PyroModule()
    self._batch_shapes = {}
    self._unconstrained_event_shapes = {}
    sample_sites = OrderedDict(self.prototype_trace.iter_stochastic_nodes())
    self._auto_config(sample_sites, args, kwargs)

    # Pass 1: collect unconstrained shapes and flattened initial values.
    flat_locs = {}
    sizes = {}
    for name, node in sample_sites.items():
        with helpful_support_errors(node):
            to_unconstrained = biject_to(node["fn"].support).inv
            unconstrained = to_unconstrained(node["value"].detach()).detach()
        batch_shape = node["fn"].batch_shape
        self._batch_shapes[name] = batch_shape
        self._unconstrained_event_shapes[name] = unconstrained.shape[
            len(batch_shape):
        ]
        sizes[name] = unconstrained.numel()
        flat_locs[name] = unconstrained.reshape(-1)

    # Pass 2: register guide parameters and dependency modules.
    downstream = defaultdict(list)
    pending = {}
    for name in sample_sites:
        # Location parameter.
        flat_loc = flat_locs[name]
        deep_setattr(self.locs, name, PyroParam(flat_loc))

        # Parameters of the conditional distribution.
        conditional = self.conditionals[name]
        if callable(conditional):
            deep_setattr(self.conds, name, conditional)
        elif conditional not in ("delta", "normal", "mvn"):
            raise ValueError(f"Unsupported conditional type: {conditional}")
        else:
            if conditional in ("normal", "mvn"):
                scale = torch.full_like(flat_loc, self._init_scale)
                scale_param = PyroParam(scale, self.scale_constraint)
                deep_setattr(self.scales, name, scale_param)
            if conditional == "mvn":
                scale_tril = eye_like(flat_loc, flat_loc.numel())
                tril_param = PyroParam(scale_tril, self.scale_tril_constraint)
                deep_setattr(self.scale_trils, name, tril_param)

        # Dependencies on upstream variables.
        pending[name] = 0
        site_deps = PyroModule()
        deep_setattr(self.deps, name, site_deps)
        for upstream, dep in self.dependencies.get(name, {}).items():
            assert upstream in sample_sites
            downstream[upstream].append(name)
            pending[name] += 1
            if isinstance(dep, str) and dep == "linear":
                # A linear dependency starts out with zero weights.
                dep = torch.nn.Linear(sizes[upstream], sizes[name], bias=False)
                dep.weight.data.zero_()
            elif not callable(dep):
                raise ValueError(
                    f"Expected either the string 'linear' or a callable, but got {dep}"
                )
            deep_setattr(site_deps, upstream, dep)

    # Topologically sort sites: repeatedly take the site with the fewest
    # unresolved upstreams, breaking ties by name.
    # TODO should we choose a more optimal structure?
    self._sorted_sites = []
    while pending:
        name, count = min(pending.items(), key=lambda kv: (kv[1], kv[0]))
        assert count == 0, f"cyclic dependency: {name}"
        pending.pop(name)
        for child in downstream[name]:
            pending[child] -= 1
        compressed = self._compress_site(sample_sites[name])
        self._sorted_sites.append((name, compressed))

    # Prune non-essential parts of the trace to save memory.
    for node in self.prototype_trace.nodes.values():
        node.clear()
def _getattr(obj, attr):
    """Return ``obj.<attr>``, creating an empty ``PyroModule`` if needed.

    If the attribute is missing or is ``None``, a new ``PyroModule`` is
    assigned to it first and the value is then read back from ``obj``.

    :param obj: object to read the attribute from.
    :param str attr: attribute name.
    :returns: the existing attribute value, or the value read back after a
        fresh ``PyroModule`` was assigned.
    """
    found = getattr(obj, attr, None)
    if found is None:
        setattr(obj, attr, PyroModule())
        # Read it back so the result matches what ``obj`` now exposes.
        found = getattr(obj, attr)
    return found
def _setup_prototype(self, *args, **kwargs) -> None:
    """Build the guide's parameters from the prototype trace.

    Steps, in order:

    1. run the parent setup and create empty containers;
    2. compute ``self.dependencies`` by running ``get_dependencies`` on the
       original model under ``poutine.block``;
    3. drop every site whose listed upstream sites are all observed, from
       both ``self.dependencies`` and the prototype trace;
    4. walk the remaining trace nodes, emptying each stored node in place
       and recording factors, plates and event sizes; latent sites also get
       ``locs`` and ``scales`` parameters;
    5. create ``white_vecs`` and ``prec_sqrts`` parameters, one per factor.

    :param args: positional arguments forwarded to the parent setup and
        passed to ``get_dependencies`` as the model's args.
    :param kwargs: keyword arguments forwarded to the parent setup and
        passed to ``get_dependencies`` as the model's kwargs.
    :raises ValueError: if a sample site's ``batch_shape`` differs from the
        shape derived from its plates.
    """
    super()._setup_prototype(*args, **kwargs)

    self.locs = PyroModule()
    self.scales = PyroModule()
    self.white_vecs = PyroModule()
    self.prec_sqrts = PyroModule()
    self._factors = OrderedDict()
    self._plates = OrderedDict()
    self._event_numel = OrderedDict()
    self._unconstrained_event_shapes = OrderedDict()

    # Trace model dependencies.
    # The stored model reference is read once and then dropped.
    model = self._original_model[0]
    self._original_model = None
    self.dependencies = poutine.block(get_dependencies)(model, args, kwargs)[
        "prior_dependencies"
    ]

    # Eliminate observations with no upstream latents.
    # Iterate over a list() snapshot because entries are deleted in the loop.
    for d, upstreams in list(self.dependencies.items()):
        if all(self.prototype_trace.nodes[u]["is_observed"] for u in upstreams):
            del self.dependencies[d]
            del self.prototype_trace.nodes[d]

    # Collect factors and plates.
    for d, site in self.prototype_trace.nodes.items():
        # Prune non-essential parts of the trace to save memory.
        # The node stored in the trace is emptied in place; `site` is
        # rebound to a copy that is used for the rest of this iteration.
        pruned_site, site = site, site.copy()
        pruned_site.clear()

        # Collect factors and plates.
        if site["type"] != "sample" or site_is_subsample(site):
            continue
        assert all(f.vectorized for f in site["cond_indep_stack"])
        self._factors[d] = self._compress_site(site)
        plates = frozenset(site["cond_indep_stack"])
        if site["fn"].batch_shape != _plates_to_shape(plates):
            raise ValueError(
                f"Shape mismatch at site '{d}'. "
                "Are you missing a pyro.plate() or .to_event()?"
            )
        if site["is_observed"]:
            # Break irrelevant observation plates.
            # Keep only plates shared with at least one other upstream site;
            # this reads self._plates[u], so those sites must already have
            # been visited.
            plates &= frozenset().union(
                *(self._plates[u] for u in self.dependencies[d] if u != d)
            )
        self._plates[d] = plates

        # Create location-scale parameters, one per latent variable.
        if site["is_observed"]:
            # This may slightly overestimate, e.g. for Multinomial.
            self._event_numel[d] = site["fn"].event_shape.numel()
            # Account for broken irrelevant observation plates.
            for f in set(site["cond_indep_stack"]) - plates:
                self._event_numel[d] *= f.size
            # Observed sites get no locs/scales parameters.
            continue
        with helpful_support_errors(site):
            init_loc = biject_to(site["fn"].support).inv(site["value"]).detach()
        batch_shape = site["fn"].batch_shape
        # Everything after the batch dims of the unconstrained value is
        # treated as the unconstrained event shape.
        event_shape = init_loc.shape[len(batch_shape) :]
        self._unconstrained_event_shapes[d] = event_shape
        self._event_numel[d] = event_shape.numel()
        event_dim = len(event_shape)
        deep_setattr(self.locs, d, PyroParam(init_loc, event_dim=event_dim))
        deep_setattr(
            self.scales,
            d,
            PyroParam(
                torch.full_like(init_loc, self._init_scale),
                constraint=self.scale_constraint,
                event_dim=event_dim,
            ),
        )

    # Create parameters for dependencies, one per factor.
    for d, site in self._factors.items():
        # u_size sums, over the latent upstream sites of d, the event size
        # times the size of the plates that u has but d does not.
        u_size = 0
        for u in self.dependencies[d]:
            if not self._factors[u]["is_observed"]:
                broken_shape = _plates_to_shape(self._plates[u] - self._plates[d])
                u_size += broken_shape.numel() * self._event_numel[u]
        d_size = self._event_numel[d]
        if site["is_observed"]:
            d_size = min(d_size, u_size)  # just an optimization
        batch_shape = _plates_to_shape(self._plates[d])

        # Create parameters of each Gaussian factor.
        # `init_loc` is whatever the previous loop assigned last; it is used
        # here only as a source of dtype and device.
        # NOTE(review): `init_loc` looks unbound if the previous loop never
        # reached the latent-site branch -- confirm that cannot happen.
        white_vec = init_loc.new_zeros(batch_shape + (d_size,))
        # We initialize with noise to avoid singular gradient.
        prec_sqrt = torch.rand(
            batch_shape + (u_size, d_size),
            dtype=init_loc.dtype,
            device=init_loc.device,
        )
        # Shift the uniform [0, 1) samples to [-0.5, 0.5) and scale them by
        # self._init_scale.
        prec_sqrt.sub_(0.5).mul_(self._init_scale)
        if not site["is_observed"]:
            # Initialize the [d,d] block to the identity matrix.
            prec_sqrt.diagonal(dim1=-2, dim2=-1).fill_(1)
        deep_setattr(self.white_vecs, d, PyroParam(white_vec, event_dim=1))
        deep_setattr(self.prec_sqrts, d, PyroParam(prec_sqrt, event_dim=2))