def test_hmc(model_class, X, y, kernel, likelihood):
    if model_class is SparseGPRegression or model_class is VariationalSparseGP:
        gp = model_class(X, y, kernel, X, likelihood)
    else:
        gp = model_class(X, y, kernel, likelihood)

    kernel.set_prior("variance", dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))
    kernel.set_prior("lengthscale", dist.Uniform(torch.tensor(1.0), torch.tensor(3.0)))

    hmc_kernel = HMC(gp.model, step_size=1)
    mcmc_run = MCMC(hmc_kernel, num_samples=10)

    post_trace = defaultdict(list)
    for trace, _ in mcmc_run._traces():
        variance_name = param_with_module_name(kernel.name, "variance")
        post_trace["variance"].append(trace.nodes[variance_name]["value"])
        lengthscale_name = param_with_module_name(kernel.name, "lengthscale")
        post_trace["lengthscale"].append(trace.nodes[lengthscale_name]["value"])
        if model_class is VariationalGP:
            f_name = param_with_module_name(gp.name, "f")
            post_trace["f"].append(trace.nodes[f_name]["value"])
        if model_class is VariationalSparseGP:
            u_name = param_with_module_name(gp.name, "u")
            post_trace["u"].append(trace.nodes[u_name]["value"])

    for param in post_trace:
        param_mean = torch.mean(torch.stack(post_trace[param]), 0)
        logger.info("Posterior mean - {}".format(param))
        logger.info(param_mean)
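# Hedged sketch: `mcmc_run._traces()` above is a private iterator from the
# Pyro 0.2-era MCMC API that this test targets. On current Pyro releases the
# public equivalent of that collection loop would look roughly like this
# (`gp` as constructed in the fixture above; `mcmc.run()` works here because
# `gp.model` takes no arguments):
def collect_posterior_samples(gp):
    hmc_kernel = HMC(gp.model, step_size=1)
    mcmc = MCMC(hmc_kernel, num_samples=10)
    mcmc.run()
    return mcmc.get_samples()  # dict: sample-site name -> stacked sample values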
def _register_param(self, param, mode="model"): """ Registers a parameter to Pyro. It can be seen as a wrapper for :func:`pyro.param` and :func:`pyro.sample` primitives. :param str param: Name of the parameter. :param str mode: Either "model" or "guide". """ if param in self._fixed_params: self._registered_params[param] = self._fixed_params[param] return prior = self._priors.get(param) if self.name is None: param_name = param else: param_name = param_with_module_name(self.name, param) if prior is None: constraint = self._constraints.get(param) default_value = getattr(self, param) if constraint is None: p = pyro.param(param_name, default_value) else: p = pyro.param(param_name, default_value, constraint=constraint) elif mode == "model": p = pyro.sample(param_name, prior) else: # prior != None and mode = "guide" MAP_param_name = param_name + "_MAP" # TODO: consider to init parameter from a prior call instead of mean MAP_param = pyro.param(MAP_param_name, prior.mean.detach()) p = pyro.sample(param_name, dist.Delta(MAP_param)) self._registered_params[param] = p
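# Hedged usage sketch of the mode switching implemented by `_register_param`
# above, assuming the Pyro 0.2-era `pyro.contrib.gp` API these snippets come
# from (the RBF kernel inherits `set_prior`/`set_mode`/`get_param` from the
# same `Parameterized` base class):
import torch
import pyro.distributions as dist
from pyro.contrib.gp.kernels import RBF

kernel = RBF(input_dim=1)
kernel.set_prior("variance", dist.Uniform(torch.tensor(0.5), torch.tensor(1.5)))

kernel.set_mode("model")
v_model = kernel.get_param("variance")  # pyro.sample'd from the prior

kernel.set_mode("guide")
v_guide = kernel.get_param("variance")  # Delta at the learnable "variance_MAP" point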
def model(self): self.set_mode("model") f_loc = self.get_param("f_loc") f_scale_tril = self.get_param("f_scale_tril") N = self.X.shape[0] Kff = self.kernel(self.X) + (torch.eye(N, out=self.X.new_empty(N, N)) * self.jitter) Lff = Kff.potrf(upper=False) zero_loc = self.X.new_zeros(f_loc.shape) f_name = param_with_module_name(self.name, "f") if self.whiten: Id = torch.eye(N, out=self.X.new_empty(N, N)) pyro.sample(f_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim() - 1)) f_scale_tril = Lff.matmul(f_scale_tril) else: pyro.sample(f_name, dist.MultivariateNormal(zero_loc, scale_tril=Lff) .independent(zero_loc.dim() - 1)) f_var = f_scale_tril.pow(2).sum(dim=-1) if self.whiten: f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1) f_loc = f_loc + self.mean_function(self.X) if self.y is None: return f_loc, f_var else: return self.likelihood(f_loc, f_var, self.y)
def model(self): self.set_mode("model") Xu = self.get_param("Xu") u_loc = self.get_param("u_loc") u_scale_tril = self.get_param("u_scale_tril") M = Xu.shape[0] Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter Luu = Kuu.potrf(upper=False) zero_loc = Xu.new_zeros(u_loc.shape) u_name = param_with_module_name(self.name, "u") if self.whiten: Id = torch.eye(M, out=Xu.new_empty(M, M)) pyro.sample(u_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim() - 1)) else: pyro.sample(u_name, dist.MultivariateNormal(zero_loc, scale_tril=Luu) .independent(zero_loc.dim() - 1)) f_loc, f_var = conditional(self.X, Xu, self.kernel, u_loc, u_scale_tril, Luu, full_cov=False, whiten=self.whiten, jitter=self.jitter) f_loc = f_loc + self.mean_function(self.X) if self.y is None: return f_loc, f_var else: with poutine.scale(None, self.num_data / self.X.shape[0]): return self.likelihood(f_loc, f_var, self.y)
def module(name, nn_module, update_module_params=False):
    """
    Takes a torch.nn.Module and registers its parameters with the ParamStore.
    In conjunction with the ParamStore save() and load() functionality, this
    allows the user to save and load modules.

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param update_module_params: determines whether Parameters in the PyTorch
        module get overridden with the values found in the ParamStore (if any).
        Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name, since contains %s" % \
        _MODULE_NAMESPACE_DIVIDER

    if isclass(nn_module):
        raise NotImplementedError("pyro.module does not support class constructors for " +
                                  "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        if param_value.requires_grad:
            # register the parameter in the module with pyro
            # this only does something substantive if the parameter hasn't been seen before
            full_param_name = param_with_module_name(name, param_name)
            returned_param = param(full_param_name, param_value)

            if param_value._cdata != returned_param._cdata:
                target_state_dict[param_name] = returned_param
        else:
            warnings.warn("{} was not registered in the param store because".format(param_name) +
                          " requires_grad=False")

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit('.', 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(nn_module, mod_name)._parameters[param_name] = \
                        target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
def module(name, nn_module, tags="default", update_module_params=False): """ Takes a torch.nn.Module and registers its parameters with the ParamStore. In conjunction with the ParamStore save() and load() functionality, this allows the user to save and load modules. :param name: name of module :type name: str :param nn_module: the module to be registered with Pyro :type nn_module: torch.nn.Module :param tags: optional; tags to associate with any parameters inside the module :type tags: string or iterable of strings :param update_module_params: determines whether Parameters in the PyTorch module get overridden with the values found in the ParamStore (if any). Defaults to `False` :type load_from_param_store: bool :returns: torch.nn.Module """ assert hasattr(nn_module, "parameters"), "module has no parameters" assert _MODULE_NAMESPACE_DIVIDER not in name, "improper module name, since contains %s" %\ _MODULE_NAMESPACE_DIVIDER if isclass(nn_module): raise NotImplementedError("pyro.module does not support class constructors for " + "the argument nn_module") target_state_dict = OrderedDict() for param_name, param_value in nn_module.named_parameters(): # register the parameter in the module with pyro # this only does something substantive if the parameter hasn't been seen before full_param_name = param_with_module_name(name, param_name) returned_param = param(full_param_name, param_value, tags=tags) if get_tensor_data(param_value)._cdata != get_tensor_data(returned_param)._cdata: target_state_dict[param_name] = returned_param if target_state_dict and update_module_params: # WARNING: this is very dangerous. better method? for _name, _param in nn_module.named_parameters(): is_param = False name_arr = _name.rsplit('.', 1) if len(name_arr) > 1: mod_name, param_name = name_arr[0], name_arr[1] else: is_param = True mod_name = _name if _name in target_state_dict.keys(): if not is_param: deep_getattr(nn_module, mod_name)._parameters[param_name] = target_state_dict[_name] else: nn_module._parameters[mod_name] = target_state_dict[_name] return nn_module
def guide(self): self.set_mode("guide") f_loc = self.get_param("f_loc") f_scale_tril = self.get_param("f_scale_tril") if self._sample_latent: f_name = param_with_module_name(self.name, "f") pyro.sample(f_name, dist.MultivariateNormal(f_loc, scale_tril=f_scale_tril) .independent(f_loc.dim()-1)) return f_loc, f_scale_tril
def guide(self): self.set_mode("guide") Xu = self.get_param("Xu") u_loc = self.get_param("u_loc") u_scale_tril = self.get_param("u_scale_tril") if self._sample_latent: u_name = param_with_module_name(self.name, "u") pyro.sample(u_name, dist.MultivariateNormal(u_loc, scale_tril=u_scale_tril) .independent(u_loc.dim()-1)) return Xu, u_loc, u_scale_tril
def model(self): self.set_mode("model", recursive=False) # sample X from unit multivariate normal distribution zero_loc = self.X_loc.new_zeros(self.X_loc.shape) C = self.X_loc.shape[1] Id = torch.eye(C, out=self.X_loc.new_empty(C, C)) X_name = param_with_module_name(self.name, "X") X = pyro.sample(X_name, dist.MultivariateNormal(zero_loc, scale_tril=Id) .independent(zero_loc.dim()-1)) self.base_model.set_data(X, self.y) self.base_model.model()
def guide(self): self.set_mode("guide") f_loc = self.get_param("f_loc") f_scale_tril = self.get_param("f_scale_tril") if self._sample_latent: f_name = param_with_module_name(self.name, "f") pyro.sample( f_name, dist.MultivariateNormal( f_loc, scale_tril=f_scale_tril).independent(f_loc.dim() - 1)) return f_loc, f_scale_tril
def guide(self): self.set_mode("guide", recursive=False) # sample X from variational multivariate normal distribution X_loc = self.get_param("X_loc") X_scale_tril = self.get_param("X_scale_tril") X_name = param_with_module_name(self.name, "X") X = pyro.sample(X_name, dist.MultivariateNormal(X_loc, scale_tril=X_scale_tril) .independent(X_loc.dim()-1)) self.base_model.set_data(X, self.y) if self._call_base_model_guide: self.base_model.guide()
def guide(self): self.set_mode("guide", recursive=False) # sample X from variational multivariate normal distribution X_loc = self.get_param("X_loc") X_scale_tril = self.get_param("X_scale_tril") X_name = param_with_module_name(self.name, "X") X = pyro.sample( X_name, dist.MultivariateNormal( X_loc, scale_tril=X_scale_tril).independent(X_loc.dim() - 1)) self.base_model.set_data(X, self.y) if self._call_base_model_guide: self.base_model.guide()
def model(self): self.set_mode("model", recursive=False) # sample X from unit multivariate normal distribution zero_loc = self.X_loc.new_zeros(self.X_loc.shape) C = self.X_loc.shape[1] Id = torch.eye(C, out=self.X_loc.new_empty(C, C)) X_name = param_with_module_name(self.name, "X") X = pyro.sample( X_name, dist.MultivariateNormal( zero_loc, scale_tril=Id).independent(zero_loc.dim() - 1)) self.base_model.set_data(X, self.y) self.base_model.model()
def model(self): self.set_mode("model") Xu = self.get_param("Xu") noise = self.get_param("noise") # W = inv(Luu) @ Kuf # Qff = Kfu @ inv(Kuu) @ Kuf = W.T @ W # Fomulas for each approximation method are # DTC: y_cov = Qff + noise, trace_term = 0 # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0 # VFE: y_cov = Qff + noise, trace_term = tr(Kff-Qff) / noise # y_cov = W.T @ W + D # trace_term is added into log_prob M = Xu.shape[0] Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter Luu = Kuu.potrf(upper=False) Kuf = self.kernel(Xu, self.X) W = matrix_triangular_solve_compat(Kuf, Luu, upper=False) D = noise.expand(W.shape[1]) trace_term = 0 if self.approx == "FITC" or self.approx == "VFE": Kffdiag = self.kernel(self.X, diag=True) Qffdiag = W.pow(2).sum(dim=0) if self.approx == "FITC": D = D + Kffdiag - Qffdiag else: # approx = "VFE" trace_term += (Kffdiag - Qffdiag).sum() / noise zero_loc = self.X.new_zeros(self.X.shape[0]) f_loc = zero_loc + self.mean_function(self.X) if self.y is None: f_var = D + W.pow(2).sum(dim=0) return f_loc, f_var else: y_name = param_with_module_name(self.name, "y") return pyro.sample( y_name, dist.LowRankMultivariateNormal( f_loc, W, D, trace_term).expand_by( self.y.shape[:-1]).independent(self.y.dim() - 1), obs=self.y)
def model(self): self.set_mode("model") Xu = self.get_param("Xu") noise = self.get_param("noise") # W = inv(Luu) @ Kuf # Qff = Kfu @ inv(Kuu) @ Kuf = W.T @ W # Fomulas for each approximation method are # DTC: y_cov = Qff + noise, trace_term = 0 # FITC: y_cov = Qff + diag(Kff - Qff) + noise, trace_term = 0 # VFE: y_cov = Qff + noise, trace_term = tr(Kff-Qff) / noise # y_cov = W.T @ W + D # trace_term is added into log_prob M = Xu.shape[0] Kuu = self.kernel(Xu) + torch.eye(M, out=Xu.new_empty(M, M)) * self.jitter Luu = Kuu.potrf(upper=False) Kuf = self.kernel(Xu, self.X) W = matrix_triangular_solve_compat(Kuf, Luu, upper=False) D = noise.expand(W.shape[1]) trace_term = 0 if self.approx == "FITC" or self.approx == "VFE": Kffdiag = self.kernel(self.X, diag=True) Qffdiag = W.pow(2).sum(dim=0) if self.approx == "FITC": D = D + Kffdiag - Qffdiag else: # approx = "VFE" trace_term += (Kffdiag - Qffdiag).sum() / noise zero_loc = self.X.new_zeros(self.X.shape[0]) f_loc = zero_loc + self.mean_function(self.X) if self.y is None: f_var = D + W.pow(2).sum(dim=0) return f_loc, f_var else: y_name = param_with_module_name(self.name, "y") return pyro.sample(y_name, dist.LowRankMultivariateNormal(f_loc, W, D, trace_term) .expand_by(self.y.shape[:-1]) .independent(self.y.dim() - 1), obs=self.y)
def model(self): self.set_mode("model") noise = self.get_param("noise") Kff = self.kernel(self.X) + noise.expand(self.X.shape[0]).diag() Lff = Kff.potrf(upper=False) zero_loc = self.X.new_zeros(self.X.shape[0]) f_loc = zero_loc + self.mean_function(self.X) if self.y is None: f_var = Lff.pow(2).sum(dim=-1) return f_loc, f_var else: y_name = param_with_module_name(self.name, "y") return pyro.sample(y_name, dist.MultivariateNormal(f_loc, scale_tril=Lff) .expand_by(self.y.shape[:-1]) .independent(self.y.dim() - 1), obs=self.y)
def model(self): self.set_mode("model") noise = self.get_param("noise") Kff = self.kernel(self.X) + noise.expand(self.X.shape[0]).diag() Lff = Kff.potrf(upper=False) zero_loc = self.X.new_zeros(self.X.shape[0]) f_loc = zero_loc + self.mean_function(self.X) if self.y is None: f_var = Lff.pow(2).sum(dim=-1) return f_loc, f_var else: y_name = param_with_module_name(self.name, "y") return pyro.sample( y_name, dist.MultivariateNormal(f_loc, scale_tril=Lff).expand_by( self.y.shape[:-1]).independent(self.y.dim() - 1), obs=self.y)
def model(self): self.set_mode("model") f_loc = self.get_param("f_loc") f_scale_tril = self.get_param("f_scale_tril") N = self.X.shape[0] Kff = self.kernel( self.X) + (torch.eye(N, out=self.X.new_empty(N, N)) * self.jitter) Lff = Kff.potrf(upper=False) zero_loc = self.X.new_zeros(f_loc.shape) f_name = param_with_module_name(self.name, "f") if self.whiten: Id = torch.eye(N, out=self.X.new_empty(N, N)) pyro.sample( f_name, dist.MultivariateNormal( zero_loc, scale_tril=Id).independent(zero_loc.dim() - 1)) f_scale_tril = Lff.matmul(f_scale_tril) else: pyro.sample( f_name, dist.MultivariateNormal( zero_loc, scale_tril=Lff).independent(zero_loc.dim() - 1)) f_var = f_scale_tril.pow(2).sum(dim=-1) if self.whiten: f_loc = Lff.matmul(f_loc.unsqueeze(-1)).squeeze(-1) f_loc = f_loc + self.mean_function(self.X) if self.y is None: return f_loc, f_var else: return self.likelihood(f_loc, f_var, self.y)
def module(name, nn_module, update_module_params=False):
    """
    Registers all parameters of a :class:`torch.nn.Module` with Pyro's
    :mod:`~pyro.params.param_store`.  In conjunction with the
    :class:`~pyro.params.param_store.ParamStoreDict`
    :meth:`~pyro.params.param_store.ParamStoreDict.save` and
    :meth:`~pyro.params.param_store.ParamStoreDict.load` functionality, this
    allows the user to save and load modules.

    .. note:: Consider instead using :class:`~pyro.nn.module.PyroModule`, a
        newer alternative to ``pyro.module()`` that has better support for:
        jitting, serving in C++, and converting parameters to random
        variables. For details see the `Modules Tutorial
        <https://pyro.ai/examples/modules.html>`_ .

    :param name: name of module
    :type name: str
    :param nn_module: the module to be registered with Pyro
    :type nn_module: torch.nn.Module
    :param update_module_params: determines whether Parameters in the PyTorch
        module get overridden with the values found in the ParamStore (if any).
        Defaults to `False`
    :type update_module_params: bool
    :returns: torch.nn.Module
    """
    assert hasattr(nn_module, "parameters"), "module has no parameters"
    assert _MODULE_NAMESPACE_DIVIDER not in name, (
        "improper module name, since contains %s" % _MODULE_NAMESPACE_DIVIDER)

    if isclass(nn_module):
        raise NotImplementedError("pyro.module does not support class constructors for " +
                                  "the argument nn_module")

    target_state_dict = OrderedDict()

    for param_name, param_value in nn_module.named_parameters():
        if param_value.requires_grad:
            # register the parameter in the module with pyro
            # this only does something substantive if the parameter hasn't been seen before
            full_param_name = param_with_module_name(name, param_name)
            returned_param = param(full_param_name, param_value)

            if param_value._cdata != returned_param._cdata:
                target_state_dict[param_name] = returned_param
        elif nn_module.training:
            warnings.warn(
                f"{param_name} was not registered in the param store "
                "because requires_grad=False. You can silence this "
                "warning by calling my_module.train(False)")

    if target_state_dict and update_module_params:
        # WARNING: this is very dangerous. better method?
        for _name, _param in nn_module.named_parameters():
            is_param = False
            name_arr = _name.rsplit(".", 1)
            if len(name_arr) > 1:
                mod_name, param_name = name_arr[0], name_arr[1]
            else:
                is_param = True
                mod_name = _name
            if _name in target_state_dict.keys():
                if not is_param:
                    deep_getattr(nn_module, mod_name)._parameters[param_name] = \
                        target_state_dict[_name]
                else:
                    nn_module._parameters[mod_name] = target_state_dict[_name]

    return nn_module
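# Usage sketch for `pyro.module`: registering an nn.Module's parameters so
# that a Pyro-driven optimizer can update them through the param store (the
# module and site names here are illustrative):
import torch
import torch.nn as nn
import pyro

net = nn.Linear(3, 1)

def model(x):
    # registers net's parameters under "net.weight" / "net.bias"; repeated
    # calls are cheap and keep the module attached to the param store
    pyro.module("net", net)
    return net(x)

model(torch.randn(5, 3))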
def __init__(self, name=None):
    super(Likelihood, self).__init__(name)
    self.y_name = param_with_module_name(name, "y") if name is not None else "y"