def test_psd_safe_cholesky_psd(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        for batch_mode in (False, True):
            if batch_mode:
                A = self._gen_test_psd().to(device=device, dtype=dtype)
            else:
                A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
            idx = torch.arange(A.shape[-1], device=A.device)
            # default values
            Aprime = A.clone()
            Aprime[..., idx, idx] += 1e-6 if A.dtype == torch.float32 else 1e-8
            L_exp = torch.cholesky(Aprime)
            with warnings.catch_warnings(record=True) as ws:
                L_safe = psd_safe_cholesky(A)
                self.assertEqual(len(ws), 1)
                self.assertEqual(ws[-1].category, RuntimeWarning)
            self.assertTrue(torch.allclose(L_exp, L_safe))
            # user-defined value
            Aprime = A.clone()
            Aprime[..., idx, idx] += 1e-2
            L_exp = torch.cholesky(Aprime)
            with warnings.catch_warnings(record=True) as ws:
                L_safe = psd_safe_cholesky(A, jitter=1e-2)
                self.assertEqual(len(ws), 1)
                self.assertEqual(ws[-1].category, RuntimeWarning)
            self.assertTrue(torch.allclose(L_exp, L_safe))

def test_psd_safe_cholesky_psd(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        for batch_mode in (False, True):
            if batch_mode:
                A = self._gen_test_psd().to(device=device, dtype=dtype)
            else:
                A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
            idx = torch.arange(A.shape[-1], device=A.device)
            # default values
            Aprime = A.clone()
            Aprime[..., idx, idx] += 1e-6 if A.dtype == torch.float32 else 1e-8
            L_exp = torch.linalg.cholesky(Aprime)
            with warnings.catch_warnings(record=True) as w:
                # Makes sure warnings we catch don't cause `-w error` to fail
                warnings.simplefilter("always", NumericalWarning)
                L_safe = psd_safe_cholesky(A)
            self.assertTrue(any(issubclass(w_.category, NumericalWarning) for w_ in w))
            self.assertTrue(any("A not p.d., added jitter" in str(w_.message) for w_ in w))
            self.assertTrue(torch.allclose(L_exp, L_safe))
            # user-defined value
            Aprime = A.clone()
            Aprime[..., idx, idx] += 1e-2
            L_exp = torch.linalg.cholesky(Aprime)
            with warnings.catch_warnings(record=True) as w:
                # Makes sure warnings we catch don't cause `-w error` to fail
                warnings.simplefilter("always", NumericalWarning)
                L_safe = psd_safe_cholesky(A, jitter=1e-2)
            self.assertTrue(any(issubclass(w_.category, NumericalWarning) for w_ in w))
            self.assertTrue(any("A not p.d., added jitter" in str(w_.message) for w_ in w))
            self.assertTrue(torch.allclose(L_exp, L_safe))

def get_weights_posterior(X: Tensor, y: Tensor, sigma_sq: float) -> MultivariateNormal:
    r"""Sample Bayesian linear regression weights.

    Args:
        X: a `n x num_rff_features`-dim tensor of inputs
        y: a `n`-dim tensor of outputs
        sigma_sq: the noise variance

    Returns:
        The posterior distribution over the weights.
    """
    with torch.no_grad():
        A = X.T @ X + sigma_sq * torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
        # mean is given by: m = S @ x.T @ y, where S = A_inv
        # compute inverse of A using solves
        # covariance is A_inv * sigma
        L_A = psd_safe_cholesky(A)
        # solve L_A @ u = I
        Iw = torch.eye(L_A.shape[0], dtype=X.dtype, device=X.device)
        u = torch.triangular_solve(Iw, L_A, upper=False).solution
        # solve L_A^T @ S = u
        A_inv = torch.triangular_solve(u, L_A.T).solution
        m = A_inv @ X.T @ y
        L = psd_safe_cholesky(A_inv * sigma_sq)
        return MultivariateNormal(loc=m, scale_tril=L)

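# --- Usage sketch (not from the original source). Shows how the posterior
# returned by `get_weights_posterior` above can be sampled to get random draws
# of the linear (e.g. random-feature) model. The feature matrix, targets, and
# noise level here are arbitrary stand-ins; it assumes the definition above and
# its imports (`Tensor`, `MultivariateNormal`, `psd_safe_cholesky`) are in scope.
import torch

n, num_rff_features = 50, 16
X = torch.randn(n, num_rff_features, dtype=torch.double)
w_true = torch.randn(num_rff_features, dtype=torch.double)
y = X @ w_true + 0.1 * torch.randn(n, dtype=torch.double)

weights_posterior = get_weights_posterior(X, y, sigma_sq=1e-2)
w_samples = weights_posterior.sample(torch.Size([8]))  # 8 x num_rff_features
f_samples = X @ w_samples.T                            # n x 8 posterior draws of X @ w
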
def cholesky_safe(K, jitter=1e-6, max_tries=100):
    try:
        lower_chol = psd_safe_cholesky(K, jitter=jitter, max_tries=1)
        return lower_chol
    except RuntimeError:
        print("Not Choleskizable with jitter {}".format(jitter))
    for i in range(max_tries):
        inc_jitter = 5**i * jitter
        print("Increasing jitter to {} and retrying.".format(inc_jitter))
        try:
            lower_chol = psd_safe_cholesky(K, jitter=inc_jitter, max_tries=1)
            return lower_chol
        except RuntimeError:
            continue
    raise RuntimeError("Not Choleskizable at all.")

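# --- Usage sketch (not from the original source). `cholesky_safe` above retries
# with geometrically increasing jitter (jitter, 5*jitter, 25*jitter, ...) when a
# single-attempt factorization fails; a rank-deficient Gram matrix is a typical
# input that needs this. The matrix below is an arbitrary example.
import torch

X = torch.randn(10, 3, dtype=torch.double)
K = X @ X.T                        # 10 x 10 but only rank 3, so at best PSD
L = cholesky_safe(K, jitter=1e-6)  # lower-triangular factor of a (jittered) K
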
def sample(self, S, L, jitter=1e-5):
    """ Sample the GRF at generalized location (S, L).

    Parameters
    ----------
    S: (M, d) Tensor
        List of spatial locations.
    L: (M) Tensor
        List of response indices.
    jitter: float
        Jitter to add if covariance matrix is not diagonalisable.

    Returns
    -------
    Z: (M) Tensor
        The sampled value of Z_{s_i} component l_i.

    """
    K = self.covariance.K(S, S, L, L)
    # chol = torch.cholesky(K)
    mu = self.mean(S, L)

    # Sample M independent N(0, 1) RVs.
    # TODO: Determine if this is better than doing Cholesky ourselves.
    lower_chol = psd_safe_cholesky(K, jitter=jitter)
    distr = MultivariateNormal(loc=mu, scale_tril=lower_chol)
    sample = distr.sample()
    # sample = mu + chol @ v

    return sample.float()

def log_marginal(self, Y, gauss_mean, gauss_cov, **kwargs):
    """ Computes the log marginal likelihood w.r.t the prior

            log p(y|x) = -1/2 (Y-mu)' @ (K + sigma^2 I)^{-1} @ (Y-mu)
                         - 1/2 log|K + sigma^2 I| - N/2 log(2 pi)

    Args:
        `Y`          (torch.tensor) :->: Observations Y with shape (Dy,MB)
        `gauss_mean` (torch.tensor) :->: mean from p(f). Shape (Dy,MB)
        `gauss_cov`  (torch.tensor) :->: full covariance from p(f). Shape (Dy,MB,MB)
    """
    N = Y.size(1)
    Dy = self.out_dim

    # compute mean and covariance from the marginal distribution p(y|x).
    # This basically adds the observation noise to the covariance.
    mx, Kxx = self.marginal_moments(gauss_mean, gauss_cov, diagonal=False)

    # reshapes
    mx = mx.view(Dy, N, 1)
    Y = Y.view(Dy, N, 1)

    # solve using cholesky
    Y_mx = Y - mx
    Lxx = psd_safe_cholesky(Kxx, upper=False, jitter=cg.global_jitter)

    # Compute (Y-mu)' @ (K + sigma^2 I)^{-1} @ (Y-mu)
    rhs = torch.cholesky_solve(Y_mx, Lxx, upper=False)
    data_fit_term = torch.matmul(Y_mx.transpose(1, 2), rhs)
    complexity_term = 2 * torch.log(torch.diagonal(Lxx, dim1=1, dim2=2)).sum(1)
    cte = -N / 2. * torch.log(2 * cg.pi)

    return -0.5 * (data_fit_term + complexity_term) + cte

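# --- Standalone sketch (not from the original source) of the identity used in
# `log_marginal` above: with L = chol(K), the quadratic term comes from
# cholesky_solve and log|K| = 2 * sum(log(diag(L))). Checked against
# torch.distributions for a random well-conditioned PD matrix.
import math
import torch

N = 6
A = torch.randn(N, N, dtype=torch.double)
K = A @ A.T + N * torch.eye(N, dtype=torch.double)  # well-conditioned PD matrix
mu = torch.zeros(N, dtype=torch.double)
y = torch.randn(N, dtype=torch.double)

L = torch.linalg.cholesky(K)
r = (y - mu).unsqueeze(-1)
data_fit = (r.transpose(-1, -2) @ torch.cholesky_solve(r, L)).squeeze()
complexity = 2.0 * torch.log(torch.diagonal(L)).sum()
log_marginal = -0.5 * (data_fit + complexity) - 0.5 * N * math.log(2.0 * math.pi)

reference = torch.distributions.MultivariateNormal(mu, covariance_matrix=K).log_prob(y)
assert torch.allclose(log_marginal, reference)
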
def test_pivoted_cholesky(self, max_iter=3):
    mat = self._create_mat().detach().requires_grad_(True)
    mat.register_hook(_ensure_symmetric_grad)
    mat_copy = mat.detach().clone().requires_grad_(True)
    mat_copy.register_hook(_ensure_symmetric_grad)

    # Forward (with function)
    res, pivots = pivoted_cholesky(mat, rank=max_iter, return_pivots=True)

    # Forward (manual pivoting, actual Cholesky)
    inverse_pivots = inverse_permutation(pivots)
    # Apply pivoting
    pivoted_mat_copy = apply_permutation(mat_copy, pivots, pivots)
    # Compute Cholesky
    actual_pivoted = psd_safe_cholesky(pivoted_mat_copy)[..., :max_iter]
    # Undo pivoting
    actual = apply_permutation(actual_pivoted, left_permutation=inverse_pivots)
    self.assertAllClose(res, actual)

    # Backward
    grad_output = torch.randn_like(res)
    res.backward(gradient=grad_output)
    actual.backward(gradient=grad_output)
    self.assertAllClose(mat.grad, mat_copy.grad)

def remove_inducing_points(self, idxs):
    """
    Remove the inducing points corresponding to provided indices.
    """
    mask = torch.ones(len(self.inducing_inputs), dtype=bool)
    mask[idxs] = torch.tensor(False)

    strat = self.variational_strategy
    dist = strat._variational_distribution

    # Point process probabilities
    raw_p = self.variational_point_process.raw_probabilities[mask]
    _reregister(self.variational_point_process, "raw_probabilities", raw_p)

    # Inducing inputs
    x_u = strat.inducing_points.data[mask]
    _reregister(strat, "inducing_points", x_u)

    # Inducing outputs
    def _chol(L):
        return psd_safe_cholesky(L @ L.T + I * 1e-8)

    I = torch.eye(len(x_u), device=x_u.device)
    m = dist.variational_mean
    L = dist.chol_variational_covar
    if m.dim() == 1:
        m = m.data[mask]
    elif m.dim() == 2:
        m = m.data[:, mask]
    else:
        raise NotImplementedError
    if L.dim() == 2:
        S = (L @ L.T)[mask][:, mask]
        L = psd_safe_cholesky(S + I * 1e-8)
    elif L.dim() == 3:
        S = (L @ L.transpose(1, 2))[:, mask][:, :, mask]
        L = psd_safe_cholesky(S + I * 1e-8)
    else:
        raise NotImplementedError
    _reregister(dist, "variational_mean", m)
    _reregister(dist, "chol_variational_covar", L)

    self.clear_caches()

def _update_covar(self) -> None:
    r"""Update values derived from the data and hyperparameters.

    covar, covar_chol, and covar_inv will be of shape batch_shape x n x n
    """
    self.covar = self._calc_covar(self.datapoints, self.datapoints)
    self.covar_chol = psd_safe_cholesky(self.covar)
    self.covar_inv = self._batch_chol_inv(self.covar_chol)

def test_psd_safe_cholesky_pd(self, cuda=False):
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        for batch_mode in (False, True):
            if batch_mode:
                A = self._gen_test_psd().to(device=device, dtype=dtype)
                D = torch.eye(2).type_as(A).unsqueeze(0).repeat(2, 1, 1)
            else:
                A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
                D = torch.eye(2).type_as(A)
            A += D
            # basic
            L = torch.cholesky(A)
            L_safe = psd_safe_cholesky(A)
            self.assertTrue(torch.allclose(L, L_safe))
            # upper
            L = torch.cholesky(A, upper=True)
            L_safe = psd_safe_cholesky(A, upper=True)
            self.assertTrue(torch.allclose(L, L_safe))
            # output tensors
            L = torch.empty_like(A)
            L_safe = torch.empty_like(A)
            torch.cholesky(A, out=L)
            psd_safe_cholesky(A, out=L_safe)
            self.assertTrue(torch.allclose(L, L_safe))
            # output tensors, upper
            torch.cholesky(A, upper=True, out=L)
            psd_safe_cholesky(A, upper=True, out=L_safe)
            self.assertTrue(torch.allclose(L, L_safe))
            # make sure jitter doesn't do anything if p.d.
            L = torch.cholesky(A)
            L_safe = psd_safe_cholesky(A, jitter=1e-2)
            self.assertTrue(torch.allclose(L, L_safe))

def update_variational_distribution(self, x_new, y_new):
    m_b, S_b = self._update_variational_moments(x_new, y_new)

    q_mean = self.variational_strategy._variational_distribution.variational_mean
    q_mean.data.copy_(m_b.squeeze(-1))

    upper_new_covar = psd_safe_cholesky(S_b, jitter=self._jitter)
    upper_q_covar = self.variational_strategy._variational_distribution.chol_variational_covar
    upper_q_covar.copy_(upper_new_covar)

    self.variational_strategy.variational_params_initialized.fill_(1)

def _update_covar(self, datapoints: Tensor) -> None:
    r"""Update values derived from the data and hyperparameters.

    covar, covar_chol, and covar_inv will be of shape batch_shape x n x n

    Args:
        datapoints: (Transformed) datapoints for finding f_max
    """
    self.covar = self._calc_covar(datapoints, datapoints)
    self.covar_chol = psd_safe_cholesky(self.covar)
    self.covar_inv = self._batch_chol_inv(self.covar_chol)

def sample(self, num_context, num_target):
    r"""
    Args:
        num_context (int): Number of context points at the sample.
        num_target (int): Number of target points at the sample.

    Returns:
        :class:`Tensor`. Differs between train mode and test mode:

        * `train`: `num_context x x_dim`, `num_context x y_dim`,
          `num_total x x_dim`, `num_total x y_dim`
        * `test`: `num_context x x_dim`, `num_context x y_dim`,
          `400 x x_dim`, `400 x y_dim`
    """
    if self.train:
        num_total = num_context + num_target
        x_values = torch.empty(num_total, self.x_dim).uniform_(*self.data_range)
    else:
        lower, upper = self.data_range
        num_total = int((upper - lower) / 0.01 + 1)
        x_values = torch.linspace(self.data_range[0], self.data_range[1], num_total).unsqueeze(-1)

    if self.random_params:
        length_scale = torch.empty(self.y_dim, self.x_dim).uniform_(0.1, self.length_scale)  # [y, x]
        output_scale = torch.empty(self.y_dim).uniform_(0.1, self.output_scale)  # [y]
    else:
        length_scale = torch.full((self.y_dim, self.x_dim), self.length_scale)
        output_scale = torch.full((self.y_dim,), self.output_scale)

    # [y_dim, num_total, num_total]
    covariance = self.kernel(x_values, length_scale, output_scale)
    cholesky = psd_safe_cholesky(covariance)

    # [y_dim, num_total, num_total] x [y_dim, num_total, 1] -> [num_total, y_dim]
    y_values = cholesky.matmul(torch.randn(self.y_dim, num_total, 1)).squeeze(2).transpose(0, 1)

    if self.train:
        context_x = x_values[:num_context, :]
        context_y = y_values[:num_context, :]
    else:
        idx = torch.randperm(num_total)
        context_x = torch.gather(x_values, 0, idx[:num_context].unsqueeze(-1))
        context_y = torch.gather(y_values, 0, idx[:num_context].unsqueeze(-1))

    return context_x, context_y, x_values, y_values

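# --- Standalone sketch (not from the original source) of the sampling step used
# above: a GP prior path is drawn as y = L @ eps with L = psd_safe_cholesky(K).
# The squared-exponential kernel here is a hand-rolled stand-in for
# `self.kernel`; the import path is GPyTorch's usual location for
# psd_safe_cholesky and may differ across versions.
import torch
from gpytorch.utils.cholesky import psd_safe_cholesky

x = torch.linspace(-2.0, 2.0, 100).unsqueeze(-1)  # 100 x 1 inputs
sq_dist = (x - x.transpose(0, 1)) ** 2
K = torch.exp(-0.5 * sq_dist / 0.25 ** 2)         # output_scale=1, length_scale=0.25
L = psd_safe_cholesky(K)                          # adds jitter if K is only PSD
y = L @ torch.randn(100, 1)                       # one prior function draw
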
def _update_variational_moments(self, x, y):
    C = self.current_C_matrix(x)
    c = self.current_c_vec(x, y)

    z_b = self.variational_strategy.inducing_points
    Kbb = self.covar_module(z_b).evaluate()

    L = psd_safe_cholesky(Kbb + C.evaluate(), upper=False, jitter=self._jitter)
    m_b = Kbb @ torch.cholesky_solve(c, L, upper=False)
    S_b = Kbb @ torch.cholesky_solve(Kbb, L, upper=False)

    return m_b, S_b

def _update_inducing_points(gp, Z, m, S=None, L=None):
    # Setting the inducing points for a given layer
    assert (S is None) ^ (L is None)
    strat = gp.variational_strategy
    dist = strat._variational_distribution
    strat.inducing_points.data = Z
    dist.variational_mean.data = m
    if S is not None:
        L = psd_safe_cholesky(S)
    dist.chol_variational_covar.data = L
    gp.clear_caches()

def _inducing_inv_root(self):
    if not self.training and hasattr(self, "_cached_kernel_inv_root"):
        return self._cached_kernel_inv_root
    else:
        chol = psd_safe_cholesky(self._inducing_mat, upper=True)
        eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype)
        inv_root = torch.triangular_solve(eye, chol)[0]

        res = inv_root
        if not self.training:
            self._cached_kernel_inv_root = res
        return res

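# --- Standalone sketch (not from the original source) of the linear algebra in
# `_inducing_inv_root` above: with an upper Cholesky factor U (K = U^T @ U), the
# triangular solve U @ R = I gives R = U^{-1}, and R @ R^T = K^{-1}. Written with
# the current torch.linalg API on an arbitrary PD matrix.
import torch

A = torch.randn(5, 5, dtype=torch.double)
K = A @ A.T + 5 * torch.eye(5, dtype=torch.double)
U = torch.linalg.cholesky(K).T  # upper factor, K = U^T @ U
R = torch.linalg.solve_triangular(U, torch.eye(5, dtype=torch.double), upper=True)
assert torch.allclose(R @ R.T, torch.linalg.inv(K))
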
def initialize_variational_dist(self):
    """Initialize variational distribution.

    Describes what distribution to pass to the VariationalDistribution to
    initialize with. Most commonly, this should be the prior distribution for
    the inducing points, N(m_u, K_uu). However, if a subclass assumes a
    different parameterization of the variational distribution, it may need to
    modify what the prior is with respect to that reparameterization.
    """
    prior_dist = self.prior_distribution
    eval_prior_dist = torch.distributions.MultivariateNormal(
        loc=prior_dist.mean,
        scale_tril=psd_safe_cholesky(prior_dist.covariance_matrix),
    )
    self.variational_distribution.initialize_variational_distribution(eval_prior_dist)

def update_distribution(self, model, X, Y, X_u, likelihood):
    n_u = len(X_u)
    v = likelihood.noise_covar.noise
    b = 1 / v

    full_inputs = torch.cat([X_u, X], dim=-2)
    full_output = model.forward(full_inputs)
    full_covar = full_output.lazy_covariance_matrix

    K_MM = full_covar[..., :n_u, :n_u].add_jitter()
    K_MN = full_covar[..., :n_u, n_u:].evaluate()
    K_MM_inv = torch.inverse(K_MM.evaluate())

    a = K_MM.inv_matmul(K_MN)
    inner = a.matmul(a.T)
    A = b * inner + K_MM_inv
    S = torch.inverse(A)
    m = b * S @ K_MM_inv @ K_MN @ Y

    self.variational_mean = m.flatten()
    self.chol_variational_covar = psd_safe_cholesky(S)

def sample(self, jitter=1e-5):
    """ Sample the discretized GRF on the whole grid.

    Parameters
    ----------
    jitter: float
        Jitter to add if covariance matrix is not diagonalisable.

    Returns
    -------
    Z: (M) Tensor
        The sampled value of Z_{s_i} component l_i.

    """
    K = self.covariance_mat.list
    mu = self.mean_vec.list

    # Sample M independent N(0, 1) RVs.
    # TODO: Determine if this is better than doing Cholesky ourselves.
    lower_chol = psd_safe_cholesky(K, jitter=jitter)
    distr = MultivariateNormal(loc=mu, scale_tril=lower_chol)
    sample = distr.sample()

    return GeneralizedVector.from_list(sample.float(), self.n_points, self.n_out)

def current_C_matrix(self, x):
    sigma2 = self.likelihood.noise
    z_b = self.variational_strategy.inducing_points
    Kbf = self.covar_module(z_b, x).evaluate()
    C1 = Kbf @ Kbf.transpose(-1, -2) / sigma2

    if self._old_C_matrix is None:
        C2 = torch.zeros_like(C1)
    else:
        assert self._old_strat is not None
        assert self._old_kernel is not None
        z_a = self._old_strat.inducing_points.detach()
        Kaa_old = self._old_kernel(z_a).add_jitter(self._jitter).detach()
        C_old = self._old_C_matrix.detach()

        Kab = self.covar_module(z_a, z_b).evaluate()
        Kaa_old_inv_Kab = Kaa_old.inv_matmul(Kab)
        C2 = Kaa_old_inv_Kab.transpose(-1, -2) @ C_old.matmul(Kaa_old_inv_Kab)

    C = C1 + C2
    L = psd_safe_cholesky(C, upper=False, jitter=self._jitter)
    L = lazy.TriangularLazyTensor(L, upper=False)
    return lazy.CholLazyTensor(L, upper=False)

# Sample and plot
# ------------------------------------------------------
# Sample all components at all locations.
sample = my_discrete_grf.sample()
plot_grid_values(my_grid, sample)

# From now on, we will consider the drawn sample as ground truth.
# ---------------------------------------------------------------
ground_truth = sample
# Save for reproducibility.
np.save("ground_truth.npy", ground_truth.numpy())

# Use it to declare the data feed.
noise_std = torch.tensor([0.1, 0.1])
# Noise distribution
lower_chol = psd_safe_cholesky(torch.diag(noise_std**2))
noise_distr = MultivariateNormal(loc=torch.zeros(n_out), scale_tril=lower_chol)


def data_feed(node_ind):
    noise_realization = noise_distr.sample()
    return ground_truth[node_ind] + noise_realization


my_sensor = DiscreteSensor(my_discrete_grf)

# Excursion threshold.
lower = torch.tensor([2.3, 22.0]).float()

# Get the real excursion set and plot it.
excursion_ground_truth = (sample.isotopic > lower).float()

def __call__(self, x, y):
    sigma2 = self.gp.likelihood.noise
    z_b = self.gp.variational_strategy.inducing_points
    Kff = self.gp.covar_module(x).evaluate()
    Kbf = self.gp.covar_module(z_b, x).evaluate()
    Kbb = self.gp.covar_module(z_b).add_jitter(self.gp._jitter)
    Q1 = Kbf.transpose(-1, -2) @ Kbb.inv_matmul(Kbf)
    Sigma1 = sigma2 * torch.eye(Q1.size(-1)).to(Q1.device)

    # logp term
    if self.gp._old_strat is None:
        num_data = y.size(-2)
        mean = torch.zeros(num_data).to(y.device)
        covar = (Q1 + Sigma1) + self.gp._jitter * torch.eye(Q1.size(-2)).to(Q1.device)
        dist = distributions.MultivariateNormal(mean, covar)
        logp_term = dist.log_prob(y.squeeze(-1)).sum() / y.size(-2)
    else:
        z_a = self.gp._old_strat.inducing_points.detach()
        Kba = self.gp.covar_module(z_b, z_a).evaluate()
        Kaa_old = self.gp._old_kernel(z_a).evaluate().detach()
        Q2 = Kba.transpose(-1, -2) @ Kbb.inv_matmul(Kba)

        zero_1 = torch.zeros(Q1.size(-2), Q2.size(-1)).to(Q1.device)
        zero_2 = torch.zeros(Q2.size(-2), Q1.size(-1)).to(Q1.device)
        Q = torch.cat([
            torch.cat([Q1, zero_1], dim=-1),
            torch.cat([zero_2, Q2], dim=-1)
        ], dim=-2)

        C_old = self.gp._old_C_matrix.detach()
        Sigma2 = Kaa_old @ C_old.inv_matmul(Kaa_old)
        Sigma2 = Sigma2 + self.gp._jitter * torch.eye(Sigma2.size(-2)).to(Sigma2.device)
        Sigma = torch.cat([
            torch.cat([Sigma1, zero_1], dim=-1),
            torch.cat([zero_2, Sigma2], dim=-1)
        ], dim=-2)

        y_hat = torch.cat([y, self.gp.pseudotargets])
        mean = torch.zeros_like(y_hat.squeeze(-1))
        covar = (Q + Sigma) + self.gp._jitter * torch.eye(Q.size(-2)).to(Q.device)
        dist = distributions.MultivariateNormal(mean, covar)
        logp_term = dist.log_prob(y_hat.squeeze(-1)).sum() / y_hat.size(-2)
        num_data = y_hat.size(-2)

    # trace term
    t1 = (Kff - Q1).diag().sum() / sigma2
    t2 = 0
    if self.gp._old_strat is not None:
        LSigma2 = psd_safe_cholesky(Sigma2, upper=False, jitter=self.gp._jitter)
        Kaa = self.gp.covar_module(z_a).evaluate().detach()
        Sigma2_inv_Kaa = torch.cholesky_solve(Kaa, LSigma2, upper=False)
        Sigma2_inv_Q2 = torch.cholesky_solve(Q2, LSigma2, upper=False)
        t2 = Sigma2_inv_Kaa.diag().sum() - Sigma2_inv_Q2.diag().sum()
    trace_term = -(t1 + t2) / 2 / num_data

    if self._combine_terms:
        return logp_term + trace_term
    else:
        return logp_term, trace_term, t1 / num_data, t2 / num_data

def _chol(L):
    return psd_safe_cholesky(L @ L.T + I * 1e-8)

def _compute_information_gain(
    self, X: Tensor, mean_M: Tensor, variance_M: Tensor, covar_mM: Tensor
) -> Tensor:
    r"""Computes the information gain at the design points `X`.

    Approximately computes the information gain at the design points `X`,
    for both MES with noisy observations and multi-fidelity MES with noisy
    observation and trace observations.

    The implementation is inspired by the paper on multi-fidelity MES by
    Takeno et. al. [Takeno2019mfmves]_. The notation in the comments in this
    function follows Appendix A in the paper.

    Args:
        X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
            with `1` `d`-dim design point each.
        mean_M, variance_M: `batch_shape x num_fantasies`-dim Tensors of
            `batch_shape` t-batches with `num_fantasies` fantasies.
            `num_fantasies = 1` for non-fantasized models.
            All are obtained without noise.
        covar_mM: `batch_shape x num_fantasies x (1 + num_trace_observations)`
            -dim Tensor. `num_fantasies = 1` for non-fantasized models.
            All are obtained without noise.

    Returns:
        A `num_fantasies x batch_shape`-dim Tensor of information gains at the
        given design points `X`.
    """
    # compute the std_m, variance_m with noisy observation
    posterior_m = self.model.posterior(X.unsqueeze(-3), observation_noise=True)
    mean_m = self.weight * posterior_m.mean.squeeze(-1)
    # batch_shape x num_fantasies x (1 + num_trace_observations)
    variance_m = posterior_m.mvn.covariance_matrix
    # batch_shape x num_fantasies x (1 + num_trace_observations)^2
    check_no_nans(variance_m)

    # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
    samples_m = self.weight * self.sampler(posterior_m).squeeze(-1)
    # s_m x batch_shape x num_fantasies x (1 + num_trace_observations)
    L = psd_safe_cholesky(variance_m)
    temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
    # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
    # batch_shape x num_fantasies x 1 x (1 + num_trace_observations)

    mean_pt1 = torch.matmul(temp_term, (samples_m - mean_m).unsqueeze(-1))
    mean_new = mean_pt1.squeeze(-1).squeeze(-1) + mean_M
    # s_m x batch_shape x num_fantasies
    variance_pt1 = torch.matmul(temp_term, covar_mM.unsqueeze(-1))
    variance_new = variance_M - variance_pt1.squeeze(-1).squeeze(-1)
    # batch_shape x num_fantasies
    stdv_new = variance_new.clamp_min(CLAMP_LB).sqrt()
    # batch_shape x num_fantasies

    # define normal distribution to compute cdf and pdf
    normal = torch.distributions.Normal(
        torch.zeros(1, device=X.device, dtype=X.dtype),
        torch.ones(1, device=X.device, dtype=X.dtype),
    )

    # Compute p(fM <= f* | ym, x, Dt)
    view_shape = (
        [self.num_mv_samples] + [1] * (len(X.shape) - 2) + [self.num_fantasies]
    )  # s_M x batch_shape x num_fantasies
    if self.X_pending is None:
        view_shape[-1] = 1
    max_vals = self.posterior_max_values.view(view_shape).unsqueeze(1)
    # s_M x 1 x batch_shape x num_fantasies
    normalized_mvs_new = (max_vals - mean_new) / stdv_new
    # s_M x s_m x batch_shape x num_fantasies =
    # s_M x 1 x batch_shape x num_fantasies - s_m x batch_shape x num_fantasies
    cdf_mvs_new = normal.cdf(normalized_mvs_new).clamp_min(CLAMP_LB)

    # Compute p(fM <= f* | x, Dt)
    stdv_M = variance_M.sqrt()
    normalized_mvs = (max_vals - mean_M) / stdv_M
    # s_M x 1 x batch_shape x num_fantasies =
    # s_M x 1 x 1 x num_fantasies - batch_shape x num_fantasies
    cdf_mvs = normal.cdf(normalized_mvs).clamp_min(CLAMP_LB)
    # s_M x 1 x batch_shape x num_fantasies

    # Compute log(p(ym | x, Dt))
    log_pdf_fm = posterior_m.mvn.log_prob(self.weight * samples_m).unsqueeze(0)
    # 1 x s_m x batch_shape x num_fantasies

    # H0 = H(ym | x, Dt)
    H0 = posterior_m.mvn.entropy()  # batch_shape x num_fantasies

    # regression adjusted H1 estimation, H1_hat = H1_bar - beta * (H0_bar - H0)
    # H1 = E_{f*|x, Dt}[H(ym|f*, x, Dt)]
    Z = cdf_mvs_new / cdf_mvs  # s_M x s_m x batch_shape x num_fantasies
    h1 = -Z * Z.log() - Z * log_pdf_fm  # s_M x s_m x batch_shape x num_fantasies
    check_no_nans(h1)
    dim = [0, 1]  # dimension of fm samples, fM samples
    H1_bar = h1.mean(dim=dim)
    h0 = -log_pdf_fm
    H0_bar = h0.mean(dim=dim)
    cov = ((h1 - H1_bar) * (h0 - H0_bar)).mean(dim=dim)
    beta = cov / (h0.var(dim=dim) * h1.var(dim=dim)).sqrt()
    H1_hat = H1_bar - beta * (H0_bar - H0)
    ig = H0 - H1_hat  # batch_shape x num_fantasies
    ig = ig.permute(-1, *range(ig.dim() - 1))  # num_fantasies x batch_shape
    return ig

def update_variational_parameters(self, new_x, new_y, new_inducing_points=None):
    # if new_inducing_points = None
    # this version of the variational update does NOT assume that the inducing points
    # are moving around as we add new data. because we are using gradient based updates
    # to optimize them, we do not compute any type of randomized updates to them as in
    # bui et al.
    if new_inducing_points is None:
        new_inducing_points = self.variational_strategy.inducing_points.detach().clone()

    self.set_streaming(True)

    with torch.no_grad():
        # self.register_streaming_loss(self.variational_strategy.variational_distribution)
        self.register_streaming_loss()

        if len(new_y.shape) == 1:
            new_y = new_y.view(-1, 1)

        S_a = self.variational_strategy.variational_distribution.lazy_covariance_matrix
        K_aa_old = self.variational_strategy.prior_distribution.lazy_covariance_matrix
        m_a = self.variational_strategy.variational_distribution.mean

        D_a_inv = (S_a.evaluate().inverse() - K_aa_old.evaluate().inverse())

        # compute D S_a^{-1} m_a
        pseudo_points = torch.solve(S_a.inv_matmul(m_a).unsqueeze(-1), D_a_inv)[0]

        # stack y and the pseudo points
        hat_y = torch.cat((new_y.view(-1, 1), pseudo_points))

        # we now create Sigma_\hat y = blockdiag(\sigma^2 I; D_a)
        noise_diag = self.likelihood.noise * torch.eye(new_y.size(-2)).to(new_y.device)
        zero_part = torch.zeros(new_y.size(-2), pseudo_points.size(0)).to(new_y.device)
        tophalf = torch.cat((noise_diag, zero_part), -1)
        bottomhalf = torch.cat((zero_part.t(), D_a_inv.inverse()), -1)
        sigma_hat_y = torch.cat((tophalf, bottomhalf))

        # stack the data to be able to compute covariances with it
        # (x, a)
        stacked_data = torch.cat((new_x, self.variational_strategy.inducing_points))

        K_fb = self.covar_module(stacked_data, new_inducing_points)
        K_bb = self.covar_module(new_inducing_points)

        # C = K_{hat f b} K_{bb}^{-1} K_{b hat f} + Sigma_\hat y
        pred_cov = K_fb @ (K_bb.inv_matmul(K_fb.evaluate().t())) + sigma_hat_y

        # the new mean is K_{hat f b} C^{-1} \hat y
        new_mean = K_fb.t() @ torch.solve(hat_y, pred_cov)[0].squeeze(-1).detach().contiguous()

        # the new covariance is K_bb - K_{hat f b} C^{-1} K_{b hat f}
        new_cov = K_bb - K_fb.t() @ torch.solve(K_fb.evaluate(), pred_cov)[0]
        new_variational_chol = psd_safe_cholesky(
            new_cov.evaluate(), jitter=cholesky_jitter.value()
        ).detach().contiguous()

        self.variational_strategy._variational_distribution.variational_mean.data.mul_(0.).add_(new_mean)
        self.variational_strategy._variational_distribution.chol_variational_covar.data.mul_(0.).add_(new_variational_chol)
        self.variational_strategy.inducing_points.data.mul_(0.).add_(new_inducing_points.detach())

def test_psd_safe_cholesky_nan(self, cuda=False):
    A = self._gen_test_psd().sqrt()
    with self.assertRaises(NanError) as ctx:
        psd_safe_cholesky(A)
    self.assertTrue("NaN" in str(ctx.exception))

def _compute_information_gain(
    self, X: Tensor, mean_M: Tensor, variance_M: Tensor, covar_mM: Tensor
) -> Tensor:
    r"""Computes the information gain at the design points `X`.

    Approximately computes the information gain at the design points `X`,
    for both MES with noisy observations and multi-fidelity MES with noisy
    observation and trace observations.

    The implementation is inspired by the paper on multi-fidelity MES,
    [Takeno2020mfmves]_. The notation in the comments in this function follows
    Appendix C of [Takeno2020mfmves]_.

    `num_fantasies = 1` for non-fantasized models.

    Args:
        X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
            with `1` `d`-dim design point each.
        mean_M: A `batch_shape x num_fantasies x (m)`-dim Tensor of means.
        variance_M: A `batch_shape x num_fantasies x (m)`-dim Tensor of variances.
        covar_mM: A
            `batch_shape x num_fantasies x (m) x (1 + num_trace_observations)`-dim
            Tensor of covariances.

    Returns:
        A `num_fantasies x batch_shape`-dim Tensor of information gains at the
        given design points `X` (`num_fantasies=1` for non-fantasized models).
    """
    # compute the std_m, variance_m with noisy observation
    posterior_m = self.model.posterior(X.unsqueeze(-3), observation_noise=True)
    # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
    mean_m = self.weight * posterior_m.mean.squeeze(-1)
    # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
    variance_m = posterior_m.mvn.covariance_matrix
    check_no_nans(variance_m)

    # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
    samples_m = self.weight * self.sampler(posterior_m).squeeze(-1)
    # s_m x batch_shape x num_fantasies x (m) (1 + num_trace_observations)
    L = psd_safe_cholesky(variance_m)
    temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
    # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
    # batch_shape x num_fantasies (m) x 1 x (1 + num_trace_observations)

    mean_pt1 = torch.matmul(temp_term, (samples_m - mean_m).unsqueeze(-1))
    mean_new = mean_pt1.squeeze(-1).squeeze(-1) + mean_M
    # s_m x batch_shape x num_fantasies x (m)
    variance_pt1 = torch.matmul(temp_term, covar_mM.unsqueeze(-1))
    variance_new = variance_M - variance_pt1.squeeze(-1).squeeze(-1)
    # batch_shape x num_fantasies x (m)
    stdv_new = variance_new.clamp_min(CLAMP_LB).sqrt()
    # batch_shape x num_fantasies x (m)

    # define normal distribution to compute cdf and pdf
    normal = torch.distributions.Normal(
        torch.zeros(1, device=X.device, dtype=X.dtype),
        torch.ones(1, device=X.device, dtype=X.dtype),
    )

    # Compute p(fM <= f* | ym, x, Dt)
    view_shape = torch.Size(
        [
            self.posterior_max_values.shape[0],
            # add 1s to broadcast across the batch_shape of X
            *[1 for _ in range(X.ndim - self.posterior_max_values.ndim)],
            *self.posterior_max_values.shape[1:],
        ]
    )  # s_M x batch_shape x num_fantasies x (m)
    max_vals = self.posterior_max_values.view(view_shape).unsqueeze(1)
    # s_M x 1 x batch_shape x num_fantasies x (m)
    normalized_mvs_new = (max_vals - mean_new) / stdv_new
    # s_M x s_m x batch_shape x num_fantasies x (m) =
    # s_M x 1 x batch_shape x num_fantasies x (m)
    # - s_m x batch_shape x num_fantasies x (m)
    cdf_mvs_new = normal.cdf(normalized_mvs_new).clamp_min(CLAMP_LB)

    # Compute p(fM <= f* | x, Dt)
    stdv_M = variance_M.sqrt()
    normalized_mvs = (max_vals - mean_M) / stdv_M
    # s_M x 1 x batch_shape x num_fantasies x (m) =
    # s_M x 1 x 1 x num_fantasies x (m) - batch_shape x num_fantasies x (m)
    cdf_mvs = normal.cdf(normalized_mvs).clamp_min(CLAMP_LB)
    # s_M x 1 x batch_shape x num_fantasies x (m)

    # Compute log(p(ym | x, Dt))
    log_pdf_fm = posterior_m.mvn.log_prob(self.weight * samples_m).unsqueeze(0)
    # 1 x s_m x batch_shape x num_fantasies x (m)

    # H0 = H(ym | x, Dt)
    H0 = posterior_m.mvn.entropy()  # batch_shape x num_fantasies x (m)

    # regression adjusted H1 estimation, H1_hat = H1_bar - beta * (H0_bar - H0)
    # H1 = E_{f*|x, Dt}[H(ym|f*, x, Dt)]
    Z = cdf_mvs_new / cdf_mvs  # s_M x s_m x batch_shape x num_fantasies x (m)
    # s_M x s_m x batch_shape x num_fantasies x (m)
    h1 = -Z * Z.log() - Z * log_pdf_fm
    check_no_nans(h1)
    dim = [0, 1]  # dimension of fm samples, fM samples
    H1_bar = h1.mean(dim=dim)
    h0 = -log_pdf_fm
    H0_bar = h0.mean(dim=dim)
    cov = ((h1 - H1_bar) * (h0 - H0_bar)).mean(dim=dim)
    beta = cov / (h0.var(dim=dim) * h1.var(dim=dim)).sqrt()
    H1_hat = H1_bar - beta * (H0_bar - H0)
    ig = H0 - H1_hat  # batch_shape x num_fantasies x (m)
    if self.posterior_max_values.ndim == 2:
        permute_idcs = [-1, *range(ig.ndim - 1)]
    else:
        permute_idcs = [-2, *range(ig.ndim - 2), -1]
    ig = ig.permute(*permute_idcs)  # num_fantasies x batch_shape x (m)
    return ig

def sample_cached_cholesky(
    posterior: GPyTorchPosterior,
    baseline_L: Tensor,
    q: int,
    base_samples: Tensor,
    sample_shape: torch.Size,
    max_tries: int = 6,
) -> Tensor:
    r"""Get posterior samples at the `q` new points from the joint multi-output posterior.

    Args:
        posterior: The joint posterior is over (X_baseline, X).
        baseline_L: The baseline lower triangular cholesky factor.
        q: The number of new points in X.
        base_samples: The base samples.
        sample_shape: The sample shape.
        max_tries: The number of tries for computing the Cholesky
            decomposition with increasing jitter.

    Returns:
        A `sample_shape x batch_shape x q x m`-dim tensor of posterior
        samples at the new points.
    """
    # compute bottom left covariance block
    if isinstance(posterior.mvn, MultitaskMultivariateNormal):
        lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn)
    else:
        lazy_covar = posterior.mvn.lazy_covariance_matrix
    # Get the `q` new rows of the batched covariance matrix
    bottom_rows = lazy_covar[..., -q:, :].evaluate()
    # The covariance in block form is:
    # [K(X_baseline, X_baseline), K(X_baseline, X)]
    # [K(X, X_baseline), K(X, X)]
    # bl := K(X, X_baseline)
    # br := K(X, X)
    # Get bottom right block of new covariance
    bl, br = bottom_rows.split([bottom_rows.shape[-1] - q, q], dim=-1)
    # Solve Ax = b
    # where A = K(X_baseline, X_baseline) and b = K(X, X_baseline)^T
    # and bl_chol := x^T
    # bl_chol is the new `(batch_shape) x q x n`-dim bottom left block
    # of the cholesky decomposition
    bl_chol = torch.triangular_solve(
        bl.transpose(-2, -1), baseline_L, upper=False
    ).solution.transpose(-2, -1)
    # Compute the new bottom right block of the Cholesky
    # decomposition via:
    # Cholesky(K(X, X) - bl_chol @ bl_chol^T)
    br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1)
    # TODO: technically we should make sure that we add a
    # consistent nugget to the cached covariance and the new block
    br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries)
    # Create a `(batch_shape) x q x (n+q)`-dim tensor containing the
    # `q` new bottom rows of the Cholesky decomposition
    new_Lq = torch.cat([bl_chol, br_chol], dim=-1)
    mean = posterior.mvn.mean
    base_samples = _reshape_base_samples(
        base_samples=base_samples,
        sample_shape=sample_shape,
        posterior=posterior,
    )
    if not isinstance(posterior.mvn, MultitaskMultivariateNormal):
        # add output dim
        mean = mean.unsqueeze(-1)
        # add batch dim corresponding to output dim
        new_Lq = new_Lq.unsqueeze(-3)
    new_mean = mean[..., -q:, :]
    res = (
        new_Lq.matmul(base_samples)
        .add(new_mean.transpose(-1, -2).unsqueeze(-1))
        .permute(-1, *range(posterior.mvn.loc.dim() - 1), -2, -3)
        .contiguous()
    )
    contains_nans = torch.isnan(res).any()
    contains_infs = torch.isinf(res).any()
    if contains_nans or contains_infs:
        suffix_args = []
        if contains_nans:
            suffix_args.append("nans")
        if contains_infs:
            suffix_args.append("infs")
        suffix = " and ".join(suffix_args)
        raise NanError(f"Samples contain {suffix}.")
    return res

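# --- Standalone sketch (not from the original source) of the block update that
# `sample_cached_cholesky` relies on: given a cached factor L11 of the top-left
# block, the new bottom rows of the joint Cholesky factor are [L21, L22] with
# L21 = K21 @ L11^{-T} (a triangular solve) and L22 = chol(K22 - L21 @ L21^T).
# The matrix and block sizes below are arbitrary.
import torch

n, q = 7, 3
A = torch.randn(n + q, n + q, dtype=torch.double)
K = A @ A.T + (n + q) * torch.eye(n + q, dtype=torch.double)
K11, K21, K22 = K[:n, :n], K[n:, :n], K[n:, n:]

L11 = torch.linalg.cholesky(K11)                                # cached "baseline" factor
L21 = torch.linalg.solve_triangular(L11, K21.T, upper=False).T  # q x n bottom-left block
L22 = torch.linalg.cholesky(K22 - L21 @ L21.T)                  # Schur-complement factor

new_rows = torch.cat([L21, L22], dim=-1)                        # q x (n + q) bottom rows
assert torch.allclose(new_rows, torch.linalg.cholesky(K)[n:])   # matches the full factor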