def test_to_pyro_module_():
    pyro.set_rng_seed(123)
    actual = nn.Sequential(
        nn.Linear(28 * 28, 200),
        nn.Sigmoid(),
        nn.Linear(200, 200),
        nn.Sigmoid(),
        nn.Linear(200, 10),
    )
    to_pyro_module_(actual)
    pyro.clear_param_store()

    pyro.set_rng_seed(123)
    expected = PyroModule[nn.Sequential](
        PyroModule[nn.Linear](28 * 28, 200),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](200, 200),
        PyroModule[nn.Sigmoid](),
        PyroModule[nn.Linear](200, 10),
    )
    pyro.clear_param_store()

    def assert_identical(a, e):
        assert type(a) is type(e)
        if isinstance(a, dict):
            assert set(a) == set(e)
            for key in a:
                assert_identical(a[key], e[key])
        elif isinstance(a, nn.Module):
            assert_identical(a.__dict__, e.__dict__)
        elif isinstance(a, (str, int, float, torch.Tensor)):
            assert_equal(a, e)

    assert_identical(actual, expected)

    # check output
    data = torch.randn(28 * 28)
    actual_out = actual(data)
    pyro.clear_param_store()
    expected_out = expected(data)
    assert_equal(actual_out, expected_out)

    # check randomization
    def randomize(model):
        for m in model.modules():
            for name, value in list(m.named_parameters(recurse=False)):
                setattr(
                    m,
                    name,
                    PyroSample(
                        prior=dist.Normal(0, 1)
                        .expand(value.shape)
                        .to_event(value.dim())
                    ),
                )

    randomize(actual)
    randomize(expected)
    assert_identical(actual, expected)

def as_pyro_module(module):
    to_pyro_module_(module, recurse=True)
    for m in module.modules():
        for n, p in list(m.named_parameters(recurse=False)):
            setattr(
                m,
                n,
                PyroSample(
                    dist.Normal(torch.zeros_like(p), torch.ones_like(p)).to_event()
                ),
            )
    return module

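# A minimal usage sketch (hypothetical network and data, not part of the tests above):
# as_pyro_module() turns every parameter of a plain nn.Module into a PyroSample site
# with a standard Normal prior, so repeated forward passes resample the weights.
def _example_as_pyro_module_usage():
    net = as_pyro_module(nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1)))
    x = torch.randn(2, 4)
    y1, y2 = net(x), net(x)
    assert not torch.allclose(y1, y2)  # weights were re-drawn between the two calls
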
def test_name_preserved_by_to_pyro_module():
    features = torch.randn(4)
    data = torch.randn(3)

    class Model(PyroModule):
        def __init__(self):
            super().__init__()
            self.scale = PyroParam(torch.ones(3), constraints.positive)
            self.loc = torch.nn.Linear(4, 3)

        def forward(self, features, data):
            loc = self.loc(features)
            scale = self.scale
            with pyro.plate("data", len(data)):
                pyro.sample("obs", dist.Normal(loc, scale), obs=data)

    model = Model()
    params = list(model.parameters())
    param_names = set()

    def optim_config(param_name):
        param_names.add(param_name)
        return {"lr": 0.0}

    # Record while model.loc is an nn.Module.
    loss = poutine.trace(model).get_trace(features, data).log_prob_sum()
    loss.backward()
    adam = optim.Adam(optim_config)
    adam(params)
    assert param_names
    expected_param_names = param_names.copy()
    del adam, loss
    param_names.clear()
    pyro.clear_param_store()

    # Record while model.loc is a PyroModule.
    to_pyro_module_(model.loc)
    loss = poutine.trace(model).get_trace(features, data).log_prob_sum()
    loss.backward()
    adam = optim.Adam(optim_config)
    adam(params)
    assert param_names
    actual_param_names = param_names.copy()
    del adam, loss
    param_names.clear()
    pyro.clear_param_store()

    assert actual_param_names == {"scale", "loc.weight", "loc.bias"}
    assert actual_param_names == expected_param_names

def test_bayesian_gru():
    input_size = 2
    hidden_size = 3
    batch_size = 4
    seq_len = 5

    # Construct a simple GRU.
    gru = nn.GRU(input_size, hidden_size)
    input_ = torch.randn(seq_len, batch_size, input_size)
    output, _ = gru(input_)
    assert output.shape == (seq_len, batch_size, hidden_size)
    output2, _ = gru(input_)
    assert torch.allclose(output2, output)

    # Make it Bayesian.
    to_pyro_module_(gru)
    for name, value in list(gru.named_parameters(recurse=False)):
        prior = dist.Normal(0, 1).expand(value.shape).to_event(value.dim())
        setattr(gru, name, PyroSample(prior=prior))
    output, _ = gru(input_)
    assert output.shape == (seq_len, batch_size, hidden_size)
    output2, _ = gru(input_)
    assert not torch.allclose(output2, output)

def encode(self, name: str, prior: Distribution):
    """
    Apply the encoder network to the input data to obtain the hidden layer encoding.

    Parameters
    ----------
    name
        Name of the model site being encoded.
    prior
        Prior distribution of the site; its ``mean.device`` is used to place lazily
        created normalisation parameters and encoder networks.
    """
    try:
        args, kwargs = self.args_kwargs  # stored as a tuple of (tuple, dict)
        # get the data for NN from model args/kwargs
        in_names = self.amortised_plate_sites["input"]
        x_in = [kwargs[i] if i in kwargs.keys() else args[i] for i in in_names]
        # apply data transform before passing to NN
        in_transforms = self.amortised_plate_sites["input_transform"]
        x_in = [in_transforms[i](x) for i, x in enumerate(x_in)]
        # apply learnable normalisation before passing to NN
        input_normalisation = self.amortised_plate_sites.get("input_normalisation", None)
        if input_normalisation is not None:
            for i in range(len(self.amortised_plate_sites["input"])):
                if input_normalisation[i]:
                    x_in[i] = x_in[i] * deep_getattr(self, f"input_normalisation_{i}")
        if "single" in self.encoder_mode:
            # encode with a single encoder
            res = deep_getattr(self, "one_encoder")(*x_in)
            if "multiple" in self.encoder_mode:
                # when there is a second layer of multiple encoders, fetch the per-site encoder and encode
                x_in[0] = res
                res = deep_getattr(self.multiple_encoders, name)(*x_in)
        else:
            # when there are only multiple encoders, fetch the per-site encoder and encode
            res = deep_getattr(self.multiple_encoders, name)(*x_in)
        return res
    except AttributeError:
        pass

    # Initialize lazily on first call, then retry via the recursive call at the end.
    # create normalisation parameters if necessary
    input_normalisation = self.amortised_plate_sites.get("input_normalisation", None)
    if input_normalisation is not None:
        for i in range(len(self.amortised_plate_sites["input"])):
            if input_normalisation[i]:
                deep_setattr(
                    self,
                    f"input_normalisation_{i}",
                    PyroParam(
                        torch.ones((1, self.single_n_in))
                        .to(prior.mean.device)
                        .requires_grad_(True)
                    ),
                )
    # create encoder neural networks
    if "single" in self.encoder_mode:
        if self.encoder_instance is not None:
            # copy provided encoder instance
            one_encoder = deepcopy(self.encoder_instance).to(prior.mean.device)
            # convert to pyro module
            to_pyro_module_(one_encoder)
            deep_setattr(self, "one_encoder", one_encoder)
        else:
            # create encoder instance from encoder class
            deep_setattr(
                self,
                "one_encoder",
                self.encoder_class(
                    n_in=self.single_n_in,
                    n_out=self.n_hidden["single"],
                    **self.encoder_kwargs,
                ).to(prior.mean.device),
            )
    if "multiple" in self.encoder_mode:
        # determine the number of hidden nodes
        if name in self.n_hidden.keys():
            n_hidden = self.n_hidden[name]
        else:
            n_hidden = self.n_hidden["multiple"]
        multi_encoder_kwargs = deepcopy(self.multi_encoder_kwargs)
        multi_encoder_kwargs["n_hidden"] = n_hidden
        # create multiple encoders
        if self.encoder_instance is not None:
            # copy instances
            encoder_ = deepcopy(self.encoder_instance).to(prior.mean.device)
            # convert to pyro module
            to_pyro_module_(encoder_)
            deep_setattr(self, "multiple_encoders." + name, encoder_)
        else:
            # create instances
            deep_setattr(
                self,
                "multiple_encoders." + name,
                self.encoder_class(
                    n_in=self.multiple_n_in,
                    n_out=n_hidden,
                    **multi_encoder_kwargs,
                ).to(prior.mean.device),
            )
    return self.encode(name, prior)

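# The encode() method above initialises its encoder networks lazily: the first call
# raises AttributeError (the sub-modules do not exist yet), falls through to the
# initialisation branch, and then retries itself. A self-contained sketch of that
# pattern with hypothetical names, shown here only to illustrate the control flow:
class _LazyEncoderSketch(torch.nn.Module):
    def encode(self, x):
        try:
            return self.net(x)  # raises AttributeError until `net` has been created
        except AttributeError:
            pass
        # Initialize on first use, sized from the incoming batch, then retry.
        self.net = torch.nn.Linear(x.shape[-1], 8)
        return self.encode(x)
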
def __init__(
    self,
    model,
    amortised_plate_sites: dict,
    n_in: int,
    n_hidden: dict = None,
    init_param=0,
    init_param_scale: float = 1 / 50,
    scales_offset: float = -2,
    encoder_class=FCLayersPyro,
    encoder_kwargs=None,
    multi_encoder_kwargs=None,
    encoder_instance: torch.nn.Module = None,
    create_plates=None,
    encoder_mode: Literal["single", "multiple", "single-multiple"] = "single",
):
    """
    Parameters
    ----------
    model
        Pyro model.
    amortised_plate_sites
        Dictionary with amortised plate details: the name of the observation/minibatch plate,
        indices of model args to provide to the encoder, variable names that belong to the
        observation plate, and the number of dimensions in the non-plate axis of each
        variable, such as::

            {
                "name": "obs_plate",
                "input": [0],  # expression data + (optional) batch index ([0, 2])
                "input_transform": [torch.log1p],  # how to transform input data before passing to NN
                "sites": {
                    "n_s_cells_per_location": 1,
                    "y_s_groups_per_location": 1,
                    "z_sr_groups_factors": self.n_groups,
                    "w_sf": self.n_factors,
                    "l_s_add": 1,
                },
            }

    n_in
        Number of input dimensions (for encoder_class).
    n_hidden
        Number of hidden nodes in each layer, one of 3 options:

        1. Integer denoting the number of hidden nodes.
        2. Dictionary such as ``{"single": 200, "multiple": 200}`` denoting the number of
           hidden nodes for each `encoder_mode` (see below).
        3. Dictionary allowing a different number of hidden nodes for the single encoder
           mode and for each model site::

            {
                "single": 200,
                "n_s_cells_per_location": 5,
                "y_s_groups_per_location": 5,
                "z_sr_groups_factors": 128,
                "w_sf": 128,
                "l_s_add": 5,
            }

    init_param
        Not implemented yet - initial values for amortised variables.
    init_param_scale
        How to scale/normalise initial values for the weights converting hidden layers to
        means and standard deviations.
    encoder_class
        Class for defining the encoder network.
    encoder_kwargs
        Keyword arguments for encoder_class.
    multi_encoder_kwargs
        Optional separate keyword arguments for encoder_class, useful when
        encoder_mode == "single-multiple".
    encoder_instance
        Encoder network instance; overrides the class input, and the input instance is
        copied with deepcopy.
    create_plates
        Function for creating plates.
    encoder_mode
        Use a single encoder for all variables ("single"), one encoder per variable
        ("multiple"), or a single encoder in the first step and multiple encoders in the
        second step ("single-multiple").
    """
    super().__init__(model, create_plates=create_plates)
    self.amortised_plate_sites = amortised_plate_sites
    self.encoder_mode = encoder_mode
    self.scales_offset = scales_offset
    self.softplus = SoftplusTransform()

    if n_hidden is None:
        n_hidden = {"single": 200, "multiple": 200}
    else:
        if isinstance(n_hidden, int):
            n_hidden = {"single": n_hidden, "multiple": n_hidden}
        elif not isinstance(n_hidden, dict):
            raise ValueError("n_hidden must be either int or dict")
    encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict()
    encoder_kwargs["n_hidden"] = n_hidden["single"]
    self.encoder_kwargs = encoder_kwargs
    if multi_encoder_kwargs is None:
        multi_encoder_kwargs = deepcopy(encoder_kwargs)
    self.multi_encoder_kwargs = multi_encoder_kwargs
    if "multiple" in n_hidden.keys():
        self.multi_encoder_kwargs["n_hidden"] = n_hidden["multiple"]

    self.single_n_in = n_in
    self.multiple_n_in = n_in
    self.n_out = (
        np.sum(
            [
                np.sum(amortised_plate_sites["sites"][k])
                for k in amortised_plate_sites["sites"].keys()
            ]
        )
        * 2
    )
    self.n_hidden = n_hidden
    self.encoder_class = encoder_class
    self.encoder_instance = encoder_instance
    if "single" in self.encoder_mode:
        # create a single encoder NN
        if encoder_instance is not None:
            self.one_encoder = deepcopy(encoder_instance)
            # convert to pyro module
            to_pyro_module_(self.one_encoder)
        else:
            self.one_encoder = encoder_class(
                n_in=self.single_n_in,
                n_out=self.n_hidden["single"],
                **self.encoder_kwargs,
            )
    if "multiple" in self.encoder_mode:
        self.multiple_n_in = self.n_hidden["single"]

    self.init_param_scale = init_param_scale

def _setup_prototype(self, *args, **kwargs):
    super()._setup_prototype(*args, **kwargs)

    self._event_dims = {}
    self._cond_indep_stacks = {}
    self.hidden2locs = PyroModule()
    self.hidden2scales = PyroModule()

    if "multiple" in self.encoder_mode:
        # create module for collecting multiple encoder NN
        self.multiple_encoders = PyroModule()

    # Initialize guide params
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        # Collect unconstrained event_dims, which may differ from constrained event_dims.
        with helpful_support_errors(site):
            init_loc = biject_to(site["fn"].support).inv(site["value"].detach()).detach()
        event_dim = site["fn"].event_dim + init_loc.dim() - site["value"].dim()
        self._event_dims[name] = event_dim

        # Collect independence contexts.
        self._cond_indep_stacks[name] = site["cond_indep_stack"]

        # determine the number of hidden layers
        if "multiple" in self.encoder_mode:
            if "multiple" in self.n_hidden.keys():
                n_hidden = self.n_hidden["multiple"]
            else:
                n_hidden = self.n_hidden[name]
        elif "single" in self.encoder_mode:
            n_hidden = self.n_hidden["single"]

        # add linear layer for locs and scales
        param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
        init_param = np.random.normal(
            np.zeros(param_dim),
            (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
        ).astype("float32")
        _deep_setattr(
            self.hidden2locs,
            name,
            PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
        )

        init_param = np.random.normal(
            np.zeros(param_dim),
            (np.ones(param_dim) * self.init_param_scale) / np.sqrt(n_hidden),
        ).astype("float32")
        _deep_setattr(
            self.hidden2scales,
            name,
            PyroParam(torch.tensor(init_param, device=site["value"].device, requires_grad=True)),
        )

        if "multiple" in self.encoder_mode:
            # create multiple encoders
            if self.encoder_instance is not None:
                # copy instances
                encoder_ = deepcopy(self.encoder_instance).to(site["value"].device)
                # convert to pyro module
                to_pyro_module_(encoder_)
                _deep_setattr(
                    self.multiple_encoders,
                    name,
                    encoder_,
                )
            else:
                # create instances
                _deep_setattr(
                    self.multiple_encoders,
                    name,
                    self.encoder_class(
                        n_in=self.multiple_n_in, n_out=n_hidden, **self.multi_encoder_kwargs
                    ).to(site["value"].device),
                )

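# Hedged sketch (not the class's actual downstream code, which is not shown in this
# excerpt) of how (n_hidden, site_dim) matrices like those registered in hidden2locs
# and hidden2scales can map an encoder's hidden representation to per-site posterior
# means and positive scales:
def _example_hidden_to_loc_scale(hidden, w_loc, w_scale):
    # hidden: (batch, n_hidden); w_loc, w_scale: (n_hidden, site_dim)
    loc = hidden @ w_loc                                     # (batch, site_dim) means
    scale = torch.nn.functional.softplus(hidden @ w_scale)   # (batch, site_dim) positive scales
    return loc, scale
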