def _setup_prototype(self, *args, **kwargs):
    """Register a location and a scale parameter for every stochastic site.

    The parent ``_setup_prototype`` runs first.  Then each stochastic site of
    ``self.prototype_trace`` gets:

    * an entry in ``self._event_dims``: the event dim of the value after it
      is mapped to unconstrained space;
    * an entry in ``self._cond_indep_stacks``: the site's
      ``cond_indep_stack``;
    * a ``PyroParam`` under ``self.locs`` (constraint ``constraints.real``),
      initialised from the unconstrained prototype value;
    * a ``PyroParam`` under ``self.scales`` (constraint
      ``constraints.positive``), filled with ``self._init_scale``.

    When a plate is subsampled (its ``full_size`` differs from ``size``), the
    initial location is periodically repeated up to the full plate size.
    """
    super()._setup_prototype(*args, **kwargs)

    self._event_dims = {}
    self._cond_indep_stacks = {}
    self.locs = PyroModule()
    self.scales = PyroModule()

    for site_name, site_info in self.prototype_trace.iter_stochastic_nodes():
        dist_fn = site_info["fn"]
        constrained_value = site_info["value"]
        plate_frames = site_info["cond_indep_stack"]

        # The unconstrained representation can have a different rank than
        # the constrained one, so the event dim is corrected by the rank gap.
        unconstrained = biject_to(dist_fn.support).inv(constrained_value.detach()).detach()
        site_event_dim = dist_fn.event_dim + unconstrained.dim() - constrained_value.dim()

        self._event_dims[site_name] = site_event_dim
        self._cond_indep_stacks[site_name] = plate_frames

        # Subsampled plates: tile the initial location to the full plate size.
        for plate in plate_frames:
            plate_full_size = getattr(plate, "full_size", plate.size)
            if plate_full_size == plate.size:
                continue
            unconstrained = periodic_repeat(
                unconstrained, plate_full_size, plate.dim - site_event_dim
            ).contiguous()

        scale_init = torch.full_like(unconstrained, self._init_scale)

        _deep_setattr(
            self.locs,
            site_name,
            PyroParam(unconstrained, constraints.real, site_event_dim),
        )
        _deep_setattr(
            self.scales,
            site_name,
            PyroParam(scale_init, constraints.positive, site_event_dim),
        )
def _getattr(obj, attr): obj_next = getattr(obj, attr, None) if obj_next is not None: return obj_next setattr(obj, attr, PyroModule()) return getattr(obj, attr)
def _setup_prototype(self, *args, **kwargs):
    """Create the amortised output weights (and optional per-site encoders).

    The parent ``_setup_prototype`` runs first.  Then, for each stochastic site
    of ``self.prototype_trace``, this method:

    * records the unconstrained event dim in ``self._event_dims`` and the
      site's ``cond_indep_stack`` in ``self._cond_indep_stacks``;
    * registers two float32 weight matrices of shape
      ``(n_hidden, self.amortised_plate_sites["sites"][name])``, one under
      ``self.hidden2locs`` and one under ``self.hidden2scales``.  Each is drawn
      from a normal with mean 0 and std ``self.init_param_scale / sqrt(n_hidden)``.
      How they are consumed downstream is not visible here (presumably as
      linear heads on the encoder output -- confirm);
    * when ``self.encoder_mode`` contains ``"multiple"``, registers a per-site
      encoder under ``self.multiple_encoders``.  The encoder is either a deep
      copy of ``self.encoder_instance`` (converted with ``to_pyro_module_``) or
      a fresh ``self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden,
      **self.multi_encoder_kwargs)``.

    ``n_hidden`` is resolved as follows:

    * in "multiple" mode: ``self.n_hidden["multiple"]`` if that key exists,
      otherwise ``self.n_hidden[name]``;
    * in "single" mode: ``self.n_hidden["single"]``.

    Raises:
        ValueError: if ``self.encoder_mode`` contains neither ``"multiple"``
            nor ``"single"``.  This previously failed later with an opaque
            ``UnboundLocalError`` on ``n_hidden``.
    """
    super()._setup_prototype(*args, **kwargs)

    self._event_dims = {}
    self._cond_indep_stacks = {}
    self.hidden2locs = PyroModule()
    self.hidden2scales = PyroModule()

    multiple_mode = "multiple" in self.encoder_mode
    if multiple_mode:
        # Container collecting one encoder network per site.
        self.multiple_encoders = PyroModule()

    def _resolve_n_hidden(site_name):
        # Hidden-layer width that feeds this site's output weights.
        if multiple_mode:
            if "multiple" in self.n_hidden:
                return self.n_hidden["multiple"]
            return self.n_hidden[site_name]
        if "single" in self.encoder_mode:
            return self.n_hidden["single"]
        raise ValueError(
            "encoder_mode must contain 'single' or 'multiple', "
            f"got {self.encoder_mode!r}"
        )

    def _random_weight(shape, n_hidden, device):
        # One normal draw from numpy's global RNG, cast to float32 and wrapped
        # as a trainable PyroParam on ``device``.
        init_param = np.random.normal(
            np.zeros(shape),
            (np.ones(shape) * self.init_param_scale) / np.sqrt(n_hidden),
        ).astype("float32")
        return PyroParam(torch.tensor(init_param, device=device, requires_grad=True))

    for name, site in self.prototype_trace.iter_stochastic_nodes():
        # The unconstrained event dim may differ from the constrained one.
        with helpful_support_errors(site):
            init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
        event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
        self._event_dims[name] = event_dim

        # Independence (plate) contexts of this site.
        self._cond_indep_stacks[name] = site["cond_indep_stack"]

        device = site["value"].device
        n_hidden = _resolve_n_hidden(name)
        param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])

        # Draw locs first, then scales, to keep the original RNG consumption order.
        _deep_setattr(self.hidden2locs, name, _random_weight(param_dim, n_hidden, device))
        _deep_setattr(self.hidden2scales, name, _random_weight(param_dim, n_hidden, device))

        if multiple_mode:
            if self.encoder_instance is not None:
                # Copy the user-provided encoder and convert it to a PyroModule.
                encoder_ = deepcopy(self.encoder_instance).to(device)
                to_pyro_module_(encoder_)
            else:
                # Build a fresh encoder from the configured class.
                encoder_ = self.encoder_class(
                    n_in=self.multiple_n_in,
                    n_out=n_hidden,
                    **self.multi_encoder_kwargs,
                ).to(device)
            _deep_setattr(self.multiple_encoders, name, encoder_)