Example #1
    def test_psd_safe_cholesky_psd(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            for batch_mode in (False, True):
                if batch_mode:
                    A = self._gen_test_psd().to(device=device, dtype=dtype)
                else:
                    A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
                idx = torch.arange(A.shape[-1], device=A.device)
                # default values
                Aprime = A.clone()
                Aprime[..., idx, idx] += 1e-6 if A.dtype == torch.float32 else 1e-8
                L_exp = torch.cholesky(Aprime)
                with warnings.catch_warnings(record=True) as ws:
                    L_safe = psd_safe_cholesky(A)
                    self.assertEqual(len(ws), 1)
                    self.assertEqual(ws[-1].category, RuntimeWarning)
                self.assertTrue(torch.allclose(L_exp, L_safe))
                # user-defined value
                Aprime = A.clone()
                Aprime[..., idx, idx] += 1e-2
                L_exp = torch.cholesky(Aprime)
                with warnings.catch_warnings(record=True) as ws:
                    L_safe = psd_safe_cholesky(A, jitter=1e-2)
                    self.assertEqual(len(ws), 1)
                    self.assertEqual(ws[-1].category, RuntimeWarning)
                self.assertTrue(torch.allclose(L_exp, L_safe))
Example #2
    def test_psd_safe_cholesky_psd(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            for batch_mode in (False, True):
                if batch_mode:
                    A = self._gen_test_psd().to(device=device, dtype=dtype)
                else:
                    A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
                idx = torch.arange(A.shape[-1], device=A.device)
                # default values
                Aprime = A.clone()
                Aprime[..., idx, idx] += 1e-6 if A.dtype == torch.float32 else 1e-8
                L_exp = torch.linalg.cholesky(Aprime)
                with warnings.catch_warnings(record=True) as w:
                    # Makes sure warnings we catch don't cause `-w error` to fail
                    warnings.simplefilter("always", NumericalWarning)

                    L_safe = psd_safe_cholesky(A)
                    self.assertTrue(any(issubclass(w_.category, NumericalWarning) for w_ in w))
                    self.assertTrue(any("A not p.d., added jitter" in str(w_.message) for w_ in w))
                self.assertTrue(torch.allclose(L_exp, L_safe))
                # user-defined value
                Aprime = A.clone()
                Aprime[..., idx, idx] += 1e-2
                L_exp = torch.linalg.cholesky(Aprime)
                with warnings.catch_warnings(record=True) as w:
                    # Makes sure warnings we catch don't cause `-w error` to fail
                    warnings.simplefilter("always", NumericalWarning)

                    L_safe = psd_safe_cholesky(A, jitter=1e-2)
                    self.assertTrue(any(issubclass(w_.category, NumericalWarning) for w_ in w))
                    self.assertTrue(any("A not p.d., added jitter" in str(w_.message) for w_ in w))
                self.assertTrue(torch.allclose(L_exp, L_safe))
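
The two tests above check that psd_safe_cholesky falls back to adding a small diagonal jitter, and warns about it, when the plain factorization fails. Below is a minimal standalone sketch of that retry-with-jitter idea; this is not GPyTorch's actual implementation, and the name naive_psd_safe_cholesky, the 10x escalation factor, and the plain RuntimeWarning are illustrative assumptions only.

import warnings
import torch


def naive_psd_safe_cholesky(A, jitter=None, max_tries=3, upper=False):
    # Try the plain factorization first.
    try:
        return torch.linalg.cholesky(A, upper=upper)
    except RuntimeError:
        pass
    # Pick a dtype-dependent default jitter, as in the test above.
    if jitter is None:
        jitter = 1e-6 if A.dtype == torch.float32 else 1e-8
    idx = torch.arange(A.shape[-1], device=A.device)
    for i in range(max_tries):
        inc = jitter * (10 ** i)
        Aprime = A.clone()
        Aprime[..., idx, idx] += inc
        try:
            L = torch.linalg.cholesky(Aprime, upper=upper)
            warnings.warn(f"A not p.d., added jitter of {inc} to the diagonal", RuntimeWarning)
            return L
        except RuntimeError:
            continue
    raise RuntimeError("Matrix is not positive definite, even after adding jitter.")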
Example #3
def get_weights_posterior(X: Tensor, y: Tensor, sigma_sq: float) -> MultivariateNormal:
    r"""Sample bayesian linear regression weights.

    Args:
        X: a `n x num_rff_features`-dim tensor of inputs
        y: a `n`-dim tensor of outputs
        sigma_sq: the noise variance

    Returns:
        The posterior distribution over the weights.
    """
    with torch.no_grad():
        A = X.T @ X + sigma_sq * torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
        # mean is given by: m = S @ x.T @ y, where S = A_inv
        # compute inverse of A using solves
        # covariance is A_inv * sigma
        L_A = psd_safe_cholesky(A)
        # solve L_A @ u = I
        Iw = torch.eye(L_A.shape[0], dtype=X.dtype, device=X.device)
        u = torch.triangular_solve(Iw, L_A, upper=False).solution
        # solve L_A^T @ S = u
        A_inv = torch.triangular_solve(u, L_A.T).solution
        m = A_inv @ X.T @ y
        L = psd_safe_cholesky(A_inv * sigma_sq)
        return MultivariateNormal(loc=m, scale_tril=L)
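
A hypothetical usage sketch for get_weights_posterior follows, with made-up shapes and synthetic data; it assumes the function above is in scope together with a MultivariateNormal class exposing .loc and .sample (torch.distributions.MultivariateNormal behaves this way).

import torch

torch.manual_seed(0)
n, num_rff_features = 50, 8
X = torch.randn(n, num_rff_features, dtype=torch.double)
w_true = torch.randn(num_rff_features, dtype=torch.double)
y = X @ w_true + 0.1 * torch.randn(n, dtype=torch.double)

posterior = get_weights_posterior(X, y, sigma_sq=0.01)
w_samples = posterior.sample(torch.Size([16]))  # 16 x num_rff_features draws
y_fit = X @ posterior.loc                       # posterior-mean predictions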
Example #4
def cholesky_safe(K, jitter=1e-6, max_tries=100):
    try:
        lower_chol = psd_safe_cholesky(K, jitter=jitter, max_tries=1)
        return lower_chol
    except RuntimeError:
        print("Not Choleskizable with jitter {}".format(jitter))
    for i in range(max_tries):
        inc_jitter = 5**i * jitter
        print("Increasing jitter to {} and retrying.".format(inc_jitter))
        try:
            lower_chol = psd_safe_cholesky(K, jitter=inc_jitter, max_tries=1)
            return lower_chol
        except RuntimeError:
            continue
    raise RuntimeError("Not Choleskizable at all.")
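
A small usage sketch for the wrapper above, applied to a deliberately rank-deficient (positive semi-definite but singular) matrix; it assumes psd_safe_cholesky is importable, e.g. from gpytorch.utils.cholesky.

import torch
from gpytorch.utils.cholesky import psd_safe_cholesky  # assumed import path

torch.manual_seed(0)
v = torch.randn(5, 2, dtype=torch.double)
K = v @ v.T                      # rank-2, 5x5: PSD but singular
L = cholesky_safe(K, jitter=1e-6)
print(torch.dist(L @ L.T, K))    # small residual, on the order of the added jitter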
Example #5
    def sample(self, S, L, jitter=1e-5):
        """ Sample the GRF at generalized location (S, L).

        Parameters
        ----------
        S: (M, d) Tensor
            List of spatial locations.
        L: (M) Tensor
            List of response indices.
        jitter: float
            Jitter to add if the covariance matrix is not positive definite.

        Returns
        -------
        Z: (M) Tensor
            The sampled value of Z_{s_i} component l_i.

        """
        K = self.covariance.K(S, S, L, L)
        # chol = torch.cholesky(K)
        mu = self.mean(S, L)

        # Sample M independent N(0, 1) RVs.
        # TODO: Determine if this is better than doing Cholesky ourselves.
        lower_chol = psd_safe_cholesky(K, jitter=jitter)
        distr = MultivariateNormal(
                loc=mu,
                scale_tril=lower_chol)
        sample = distr.sample()

        #sample = mu + chol @ v 

        return sample.float()
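
The TODO above asks whether letting MultivariateNormal consume scale_tril is better than drawing the sample by hand. The two are equivalent in distribution, as this small standalone sketch (with an arbitrary toy covariance) illustrates.

import torch

mu = torch.zeros(4)
K = torch.tensor([[2.0, 0.5, 0.0, 0.0],
                  [0.5, 1.0, 0.2, 0.0],
                  [0.0, 0.2, 1.5, 0.3],
                  [0.0, 0.0, 0.3, 1.0]])
L = torch.linalg.cholesky(K)      # stand-in for psd_safe_cholesky(K, jitter=...)
eps = torch.randn(4)              # M independent N(0, 1) draws
manual_sample = mu + L @ eps      # same law as MultivariateNormal(mu, scale_tril=L).sample()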
Example #6
    def log_marginal(self, Y, gauss_mean, gauss_cov, **kwargs):
        """ Computes the log marginal likelihood w.r.t the prior
        
            log p(y|x) = -1/2 (Y-mu)' @ (K + sigma^2 I)^{-1} @ (Y-mu) - 1/2 log|K + sigma^2 I| - N/2 log(2*pi)

            Args:
                `Y` (torch.tensor)  :->: Observations Y with shape (Dy,MB)
                `gauss_mean` (torch.tensor)  :->:  mean from p(f). Shape (Dy,MB)
                `gauss_cov`  (torch.tensor)  :->:  full covariance from p(f). Shape (Dy,MB,MB)
        """

        N = Y.size(1)
        Dy = self.out_dim
    
        # compute mean and covariance from the marginal distribution p(y|x).
        # This basically adds the observation noise to the covariance.
        mx,Kxx = self.marginal_moments(gauss_mean,gauss_cov, diagonal = False)

        # reshapes
        mx = mx.view(Dy,N,1)
        Y  = Y.view(Dy,N,1)

        # solve using cholesky
        Y_mx = Y-mx
        Lxx = psd_safe_cholesky(Kxx, upper = False, jitter = cg.global_jitter)

        # Compute (Y-mu)' @ (K+sigma²I)^{-1} @ (Y-mu)
        rhs = torch.cholesky_solve(Y_mx, Lxx, upper = False)

        data_fit_term   = torch.matmul(Y_mx.transpose(1,2),rhs)
        complexity_term = 2*torch.log(torch.diagonal(Lxx, dim1 = 1, dim2 = 2)).sum(1) 
        cte      = -N/2. * torch.log(2*cg.pi)
        

        return -0.5*(data_fit_term + complexity_term ) + cte
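
The method above leans on two standard identities: the data-fit term computed through torch.cholesky_solve equals (Y-mu)' K^{-1} (Y-mu), and 2*sum(log diag(L)) equals log|K| for K = L L'. A small standalone float64 check of both (independent of the class and of cg):

import torch

torch.manual_seed(0)
B = torch.randn(6, 6, dtype=torch.float64)
K = B @ B.T + 6 * torch.eye(6, dtype=torch.float64)
r = torch.randn(6, 1, dtype=torch.float64)

L = torch.linalg.cholesky(K)
quad = (r.T @ torch.cholesky_solve(r, L)).item()
quad_direct = (r.T @ torch.linalg.solve(K, r)).item()
logdet = 2 * torch.log(torch.diagonal(L)).sum().item()
logdet_direct = torch.logdet(K).item()
print(abs(quad - quad_direct), abs(logdet - logdet_direct))  # both ~0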
Example #7
    def test_pivoted_cholesky(self, max_iter=3):
        mat = self._create_mat().detach().requires_grad_(True)
        mat.register_hook(_ensure_symmetric_grad)
        mat_copy = mat.detach().clone().requires_grad_(True)
        mat_copy.register_hook(_ensure_symmetric_grad)

        # Forward (with function)
        res, pivots = pivoted_cholesky(mat, rank=max_iter, return_pivots=True)

        # Forward (manual pivoting, actual Cholesky)
        inverse_pivots = inverse_permutation(pivots)
        # Apply pivoting
        pivoted_mat_copy = apply_permutation(mat_copy, pivots, pivots)
        # Compute Cholesky
        actual_pivoted = psd_safe_cholesky(pivoted_mat_copy)[..., :max_iter]
        # Undo pivoting
        actual = apply_permutation(actual_pivoted,
                                   left_permutation=inverse_pivots)

        self.assertAllClose(res, actual)

        # Backward
        grad_output = torch.randn_like(res)
        res.backward(gradient=grad_output)
        actual.backward(gradient=grad_output)
        self.assertAllClose(mat.grad, mat_copy.grad)
Example #8
    def remove_inducing_points(self, idxs):
        """
        Remove the inducing points corresponding to provided indices.
        """

        mask = torch.ones(len(self.inducing_inputs), dtype=bool)
        mask[idxs] = torch.tensor(False)

        strat = self.variational_strategy
        dist = strat._variational_distribution

        # Point process probabilities
        raw_p = self.variational_point_process.raw_probabilities[mask]
        _reregister(self.variational_point_process, "raw_probabilities", raw_p)

        # Inducing inputs
        x_u = strat.inducing_points.data[mask]
        _reregister(strat, "inducing_points", x_u)

        # Inducing outputs
        def _chol(L):
            return psd_safe_cholesky(L @ L.T + I * 1e-8)

        I = torch.eye(len(x_u), device=x_u.device)
        m = dist.variational_mean
        L = dist.chol_variational_covar

        if m.dim() == 1:
            m = m.data[mask]
        elif m.dim() == 2:
            m = m.data[:, mask]
        else:
            raise NotImplementedError

        if L.dim() == 2:
            S = (L @ L.T)[mask][:, mask]
            L = psd_safe_cholesky(S + I * 1e-8)
        elif L.dim() == 3:
            S = (L @ L.transpose(1, 2))[:, mask][:, :, mask]
            L = psd_safe_cholesky(S + I * 1e-8)
        else:
            raise NotImplementedError

        _reregister(dist, "variational_mean", m)
        _reregister(dist, "chol_variational_covar", L)

        self.clear_caches()
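
The covariance handling above relies on the fact that you cannot drop inducing points directly from a Cholesky factor: the code rebuilds S = L L^T, masks the rows and columns of S, adds a small diagonal nugget, and re-factorizes. A standalone toy sketch of that pattern (the 5-point factor and the dropped indices [1, 3] are arbitrary):

import torch

torch.manual_seed(0)
L_full = torch.randn(5, 5).tril()
L_full.diagonal().abs_().add_(1.0)           # make it a valid lower factor
S = L_full @ L_full.T

mask = torch.ones(5, dtype=torch.bool)
mask[[1, 3]] = False                         # drop inducing points 1 and 3
S_reduced = S[mask][:, mask]
I = torch.eye(int(mask.sum()))
L_reduced = torch.linalg.cholesky(S_reduced + I * 1e-8)
print(torch.allclose(L_reduced @ L_reduced.T, S_reduced, atol=1e-6))  # True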
Example #9
    def _update_covar(self) -> None:
        r"""Update values derived from the data and hyperparameters

        covar, covar_chol, and covar_inv will be of shape batch_shape x n x n
        """
        self.covar = self._calc_covar(self.datapoints, self.datapoints)
        self.covar_chol = psd_safe_cholesky(self.covar)
        self.covar_inv = self._batch_chol_inv(self.covar_chol)
Example #10
    def test_psd_safe_cholesky_pd(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        for dtype in (torch.float, torch.double):
            for batch_mode in (False, True):
                if batch_mode:
                    A = self._gen_test_psd().to(device=device, dtype=dtype)
                    D = torch.eye(2).type_as(A).unsqueeze(0).repeat(2, 1, 1)
                else:
                    A = self._gen_test_psd()[0].to(device=device, dtype=dtype)
                    D = torch.eye(2).type_as(A)
                A += D
                # basic
                L = torch.cholesky(A)
                L_safe = psd_safe_cholesky(A)
                self.assertTrue(torch.allclose(L, L_safe))
                # upper
                L = torch.cholesky(A, upper=True)
                L_safe = psd_safe_cholesky(A, upper=True)
                self.assertTrue(torch.allclose(L, L_safe))
                # output tensors
                L = torch.empty_like(A)
                L_safe = torch.empty_like(A)
                torch.cholesky(A, out=L)
                psd_safe_cholesky(A, out=L_safe)
                self.assertTrue(torch.allclose(L, L_safe))
                # output tensors, upper
                torch.cholesky(A, upper=True, out=L)
                psd_safe_cholesky(A, upper=True, out=L_safe)
                self.assertTrue(torch.allclose(L, L_safe))
                # make sure jitter doesn't do anything if p.d.
                L = torch.cholesky(A)
                L_safe = psd_safe_cholesky(A, jitter=1e-2)
                self.assertTrue(torch.allclose(L, L_safe))
Example #11
    def update_variational_distribution(self, x_new, y_new):
        m_b, S_b = self._update_variational_moments(x_new, y_new)

        q_mean = self.variational_strategy._variational_distribution.variational_mean
        q_mean.data.copy_(m_b.squeeze(-1))

        upper_new_covar = psd_safe_cholesky(S_b, jitter=self._jitter)
        upper_q_covar = self.variational_strategy._variational_distribution.chol_variational_covar
        upper_q_covar.copy_(upper_new_covar)
        self.variational_strategy.variational_params_initialized.fill_(1)
Example #12
    def _update_covar(self, datapoints: Tensor) -> None:
        r"""Update values derived from the data and hyperparameters

        covar, covar_chol, and covar_inv will be of shape batch_shape x n x n

        Args:
            datapoints: (Transformed) datapoints for finding f_max
        """
        self.covar = self._calc_covar(datapoints, datapoints)
        self.covar_chol = psd_safe_cholesky(self.covar)
        self.covar_inv = self._batch_chol_inv(self.covar_chol)
Example #13
    def sample(self, num_context, num_target):
        r"""
        Args:
            num_context (int): Number of context points at the sample.
            num_target (int): Number of target points at the sample.

        Returns:
            :class:`Tensor`.
                Differs between train mode and test mode:

                * `train`: `num_context x x_dim`, `num_context x y_dim`, `num_total x x_dim`, `num_total x y_dim`
                * `test`: `num_context x x_dim`, `num_context x y_dim`, `400 x x_dim`, `400 x y_dim`
        """
        if self.train:
            num_total = num_context + num_target
            x_values = torch.empty(num_total,
                                   self.x_dim).uniform_(*self.data_range)
        else:
            lower, upper = self.data_range
            num_total = int((upper - lower) / 0.01 + 1)
            x_values = torch.linspace(self.data_range[0], self.data_range[1],
                                      num_total).unsqueeze(-1)

        if self.random_params:
            length_scale = torch.empty(self.y_dim, self.x_dim).uniform_(
                0.1, self.length_scale)  # [y, x]
            output_scale = torch.empty(self.y_dim).uniform_(
                0.1, self.output_scale)  # [y]
        else:
            length_scale = torch.full((self.y_dim, self.x_dim),
                                      self.length_scale)
            output_scale = torch.full((self.y_dim, ), self.output_scale)

        # [y_dim, num_total, num_total]
        covariance = self.kernel(x_values, length_scale, output_scale)

        cholesky = psd_safe_cholesky(covariance)

        # [y_dim, num_total, num_total] @ [y_dim, num_total, 1] -> [num_total, y_dim]
        y_values = cholesky.matmul(torch.randn(self.y_dim, num_total,
                                               1)).squeeze(2).transpose(0, 1)

        if self.train:
            context_x = x_values[:num_context, :]
            context_y = y_values[:num_context, :]
        else:
            idx = torch.randperm(num_total)
            context_x = torch.gather(x_values, 0,
                                     idx[:num_context].unsqueeze(-1))
            context_y = torch.gather(y_values, 0,
                                     idx[:num_context].unsqueeze(-1))

        return context_x, context_y, x_values, y_values
Example #14
    def _update_variational_moments(self, x, y):
        C = self.current_C_matrix(x)
        c = self.current_c_vec(x, y)
        z_b = self.variational_strategy.inducing_points
        Kbb = self.covar_module(z_b).evaluate()
        L = psd_safe_cholesky(Kbb + C.evaluate(),
                              upper=False,
                              jitter=self._jitter)
        m_b = Kbb @ torch.cholesky_solve(c, L, upper=False)
        S_b = Kbb @ torch.cholesky_solve(Kbb, L, upper=False)

        return m_b, S_b
Example #15
        def _update_inducing_points(gp, Z, m, S=None, L=None):
            # Setting the inducing points for a given layer
            assert (S is None) ^ (L is None)

            strat = gp.variational_strategy
            dist = strat._variational_distribution

            strat.inducing_points.data = Z
            dist.variational_mean.data = m
            if S is not None:
                L = psd_safe_cholesky(S)
            dist.chol_variational_covar.data = L
            gp.clear_caches()
Example #16
    def _inducing_inv_root(self):
        if not self.training and hasattr(self, "_cached_kernel_inv_root"):
            return self._cached_kernel_inv_root
        else:
            chol = psd_safe_cholesky(self._inducing_mat, upper=True)
            eye = torch.eye(chol.size(-1),
                            device=chol.device,
                            dtype=chol.dtype)
            inv_root = torch.triangular_solve(eye, chol)[0]

            res = inv_root
            if not self.training:
                self._cached_kernel_inv_root = res
            return res
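
What the cached quantity above represents: with an upper Cholesky factor U of the inducing matrix (K = U^T U), solving U X = I yields X = U^{-1}, and X X^T = K^{-1}, i.e. an inverse root of K. A standalone float64 check, using torch.linalg.solve_triangular in place of the older torch.triangular_solve call:

import torch

torch.manual_seed(0)
B = torch.randn(4, 4, dtype=torch.float64)
K = B @ B.T + 4 * torch.eye(4, dtype=torch.float64)
U = torch.linalg.cholesky(K, upper=True)
eye = torch.eye(4, dtype=torch.float64)
inv_root = torch.linalg.solve_triangular(U, eye, upper=True)  # U^{-1}
print(torch.allclose(inv_root @ inv_root.T, torch.linalg.inv(K)))  # True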
Example #17
    def initialize_variational_dist(self):
        """Initialize variational distribution.

        Describes what distribution to pass to the VariationalDistribution to initialize
        with. Most commonly, this should be the prior distribution for the inducing
        points, N(m_u, K_uu). However, if a subclass assumes a different
        parameterization of the variational distribution, it may need to modify what the
        prior is with respect to that reparameterization.
        """
        prior_dist = self.prior_distribution
        eval_prior_dist = torch.distributions.MultivariateNormal(
            loc=prior_dist.mean,
            scale_tril=psd_safe_cholesky(prior_dist.covariance_matrix),
        )
        self.variational_distribution.initialize_variational_distribution(
            eval_prior_dist)
Example #18
    def update_distribution(self, model, X, Y, X_u, likelihood):
        n_u = len(X_u)
        v = likelihood.noise_covar.noise
        b = 1 / v

        full_inputs = torch.cat([X_u, X], dim=-2)
        full_output = model.forward(full_inputs)
        full_covar = full_output.lazy_covariance_matrix

        K_MM = full_covar[..., :n_u, :n_u].add_jitter()
        K_MN = full_covar[..., :n_u, n_u:].evaluate()
        K_MM_inv = torch.inverse(K_MM.evaluate())

        a = K_MM.inv_matmul(K_MN)
        inner = a.matmul(a.T)

        A = b * inner + K_MM_inv
        S = torch.inverse(A)
        m = b * S @ K_MM_inv @ K_MN @ Y

        self.variational_mean = m.flatten()
        self.chol_variational_covar = psd_safe_cholesky(S)
Example #19
    def sample(self, jitter=1e-5):
        """ Sample the discretized GRF on the whole grid.


        Returns
        -------
        Z: (M) Tensor
            The sampled value of Z_{s_i} component l_i.
        jitter: float
            Jitter to add if covariance matrix is not diagonalisable.

        """
        K = self.covariance_mat.list
        mu = self.mean_vec.list

        # Sample M independent N(0, 1) RVs.
        # TODO: Determine if this is better than doing Cholesky ourselves.
        lower_chol = psd_safe_cholesky(K, jitter=jitter)
        distr = MultivariateNormal(loc=mu, scale_tril=lower_chol)
        sample = distr.sample()

        return GeneralizedVector.from_list(sample.float(), self.n_points,
                                           self.n_out)
Example #20
    def current_C_matrix(self, x):
        sigma2 = self.likelihood.noise
        z_b = self.variational_strategy.inducing_points
        Kbf = self.covar_module(z_b, x).evaluate()
        C1 = Kbf @ Kbf.transpose(-1, -2) / sigma2

        if self._old_C_matrix is None:
            C2 = torch.zeros_like(C1)
        else:
            assert self._old_strat is not None
            assert self._old_kernel is not None
            z_a = self._old_strat.inducing_points.detach()
            Kaa_old = self._old_kernel(z_a).add_jitter(self._jitter).detach()
            C_old = self._old_C_matrix.detach()
            Kab = self.covar_module(z_a, z_b).evaluate()
            Kaa_old_inv_Kab = Kaa_old.inv_matmul(Kab)
            C2 = Kaa_old_inv_Kab.transpose(-1,
                                           -2) @ C_old.matmul(Kaa_old_inv_Kab)

        C = C1 + C2
        L = psd_safe_cholesky(C, upper=False, jitter=self._jitter)
        L = lazy.TriangularLazyTensor(L, upper=False)
        return lazy.CholLazyTensor(L, upper=False)
Example #21
# Sample and plot
# ------------------------------------------------------
# Sample all components at all locations.
sample = my_discrete_grf.sample()
plot_grid_values(my_grid, sample)

# From now on, we will consider the drawn sample as ground truth.
# ---------------------------------------------------------------
ground_truth = sample
# Save for reproducibility.
np.save("ground_truth.npy", ground_truth.numpy())

# Use it to declare the data feed.
noise_std = torch.tensor([0.1, 0.1])
# Noise distribution
lower_chol = psd_safe_cholesky(torch.diag(noise_std**2))
noise_distr = MultivariateNormal(loc=torch.zeros(n_out), scale_tril=lower_chol)


def data_feed(node_ind):
    noise_realization = noise_distr.sample()
    return ground_truth[node_ind] + noise_realization


my_sensor = DiscreteSensor(my_discrete_grf)

# Excursion threshold.
lower = torch.tensor([2.3, 22.0]).float()

# Get the real excursion set and plot it.
excursion_ground_truth = (sample.isotopic > lower).float()
Example #22
    def __call__(self, x, y):
        sigma2 = self.gp.likelihood.noise
        z_b = self.gp.variational_strategy.inducing_points
        Kff = self.gp.covar_module(x).evaluate()
        Kbf = self.gp.covar_module(z_b, x).evaluate()
        Kbb = self.gp.covar_module(z_b).add_jitter(self.gp._jitter)

        Q1 = Kbf.transpose(-1, -2) @ Kbb.inv_matmul(Kbf)
        Sigma1 = sigma2 * torch.eye(Q1.size(-1)).to(Q1.device)

        # logp term
        if self.gp._old_strat is None:
            num_data = y.size(-2)
            mean = torch.zeros(num_data).to(y.device)
            covar = (Q1 + Sigma1) + self.gp._jitter * torch.eye(
                Q1.size(-2)).to(Q1.device)
            dist = distributions.MultivariateNormal(mean, covar)
            logp_term = dist.log_prob(y.squeeze(-1)).sum() / y.size(-2)
        else:
            z_a = self.gp._old_strat.inducing_points.detach()
            Kba = self.gp.covar_module(z_b, z_a).evaluate()
            Kaa_old = self.gp._old_kernel(z_a).evaluate().detach()
            Q2 = Kba.transpose(-1, -2) @ Kbb.inv_matmul(Kba)
            zero_1 = torch.zeros(Q1.size(-2), Q2.size(-1)).to(Q1.device)
            zero_2 = torch.zeros(Q2.size(-2), Q1.size(-1)).to(Q1.device)
            Q = torch.cat([
                torch.cat([Q1, zero_1], dim=-1),
                torch.cat([zero_2, Q2], dim=-1)
            ],
                          dim=-2)

            C_old = self.gp._old_C_matrix.detach()
            Sigma2 = Kaa_old @ C_old.inv_matmul(Kaa_old)
            Sigma2 = Sigma2 + self.gp._jitter * torch.eye(Sigma2.size(-2)).to(
                Sigma2.device)
            Sigma = torch.cat([
                torch.cat([Sigma1, zero_1], dim=-1),
                torch.cat([zero_2, Sigma2], dim=-1)
            ],
                              dim=-2)

            y_hat = torch.cat([y, self.gp.pseudotargets])
            mean = torch.zeros_like(y_hat.squeeze(-1))
            covar = (Q + Sigma) + self.gp._jitter * torch.eye(Q.size(-2)).to(
                Q.device)
            dist = distributions.MultivariateNormal(mean, covar)
            logp_term = dist.log_prob(y_hat.squeeze(-1)).sum() / y_hat.size(-2)
            num_data = y_hat.size(-2)

        # trace term
        t1 = (Kff - Q1).diag().sum() / sigma2
        t2 = 0
        if self.gp._old_strat is not None:
            LSigma2 = psd_safe_cholesky(Sigma2,
                                        upper=False,
                                        jitter=self.gp._jitter)
            Kaa = self.gp.covar_module(z_a).evaluate().detach()
            Sigma2_inv_Kaa = torch.cholesky_solve(Kaa, LSigma2, upper=False)
            Sigma2_inv_Q2 = torch.cholesky_solve(Q2, LSigma2, upper=False)
            t2 = Sigma2_inv_Kaa.diag().sum() - Sigma2_inv_Q2.diag().sum()
        trace_term = -(t1 + t2) / 2 / num_data

        if self._combine_terms:
            return logp_term + trace_term
        else:
            return logp_term, trace_term, t1 / num_data, t2 / num_data
Example #23
def _chol(L):
    return psd_safe_cholesky(L @ L.T + I * 1e-8)
Example #24
    def _compute_information_gain(self, X: Tensor, mean_M: Tensor,
                                  variance_M: Tensor,
                                  covar_mM: Tensor) -> Tensor:
        r"""Computes the information gain at the design points `X`.

        Approximately computes the information gain at the design points `X`,
        for both MES with noisy observations and multi-fidelity MES with noisy
        observation and trace observations.

        The implementation is inspired by the paper on multi-fidelity MES by
        Takeno et al. [Takeno2019mfmves]_. The notation in the comments in this
        function follows Appendix A of the paper.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
                with `1` `d`-dim design point each.
            mean_M, variance_M: `batch_shape x num_fantasies`-dim Tensors of
                `batch_shape` t-batches with `num_fantasies` fantasies.
                `num_fantasies = 1` for non-fantasized models.
                All are obtained without noise.
            covar_mM: `batch_shape x num_fantasies x (1 + num_trace_observations)`
                -dim Tensor. `num_fantasies = 1` for non-fantasized models.
                All are obtained without noise.

        Returns:
            A `num_fantasies x batch_shape`-dim Tensor of information gains at the
            given design points `X`.
        """

        # compute the std_m, variance_m with noisy observation
        posterior_m = self.model.posterior(X.unsqueeze(-3),
                                           observation_noise=True)
        mean_m = self.weight * posterior_m.mean.squeeze(-1)
        # batch_shape x num_fantasies x (1 + num_trace_observations)
        variance_m = posterior_m.mvn.covariance_matrix
        # batch_shape x num_fantasies x (1 + num_trace_observations)^2
        check_no_nans(variance_m)

        # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
        samples_m = self.weight * self.sampler(posterior_m).squeeze(-1)
        # s_m x batch_shape x num_fantasies x (1 + num_trace_observations)
        L = psd_safe_cholesky(variance_m)
        temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1),
                                         L).transpose(-2, -1)
        # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
        # batch_shape x num_fantasies x 1 x (1 + num_trace_observations)

        mean_pt1 = torch.matmul(temp_term, (samples_m - mean_m).unsqueeze(-1))
        mean_new = mean_pt1.squeeze(-1).squeeze(-1) + mean_M
        # s_m x batch_shape x num_fantasies
        variance_pt1 = torch.matmul(temp_term, covar_mM.unsqueeze(-1))
        variance_new = variance_M - variance_pt1.squeeze(-1).squeeze(-1)
        # batch_shape x num_fantasies
        stdv_new = variance_new.clamp_min(CLAMP_LB).sqrt()
        # batch_shape x num_fantasies

        # define normal distribution to compute cdf and pdf
        normal = torch.distributions.Normal(
            torch.zeros(1, device=X.device, dtype=X.dtype),
            torch.ones(1, device=X.device, dtype=X.dtype),
        )

        # Compute p(fM <= f* | ym, x, Dt)
        view_shape = ([self.num_mv_samples] + [1] * (len(X.shape) - 2) +
                      [self.num_fantasies]
                      )  # s_M x batch_shape x num_fantasies
        if self.X_pending is None:
            view_shape[-1] = 1
        max_vals = self.posterior_max_values.view(view_shape).unsqueeze(1)
        # s_M x 1 x batch_shape x num_fantasies
        normalized_mvs_new = (max_vals - mean_new) / stdv_new
        # s_M x s_m x batch_shape x num_fantasies =
        # s_M x 1 x batch_shape x num_fantasies - s_m x batch_shape x num_fantasies
        cdf_mvs_new = normal.cdf(normalized_mvs_new).clamp_min(CLAMP_LB)

        # Compute p(fM <= f* | x, Dt)
        stdv_M = variance_M.sqrt()
        normalized_mvs = (max_vals - mean_M) / stdv_M
        # s_M x 1 x batch_shape x num_fantasies =
        # s_M x 1 x 1 x num_fantasies - batch_shape x num_fantasies
        cdf_mvs = normal.cdf(normalized_mvs).clamp_min(CLAMP_LB)
        # s_M x 1 x batch_shape x num_fantasies

        # Compute log(p(ym | x, Dt))
        log_pdf_fm = posterior_m.mvn.log_prob(self.weight *
                                              samples_m).unsqueeze(0)
        # 1 x s_m x batch_shape x num_fantasies

        # H0 = H(ym | x, Dt)
        H0 = posterior_m.mvn.entropy()  # batch_shape x num_fantasies

        # regression adjusted H1 estimation, H1_hat = H1_bar - beta * (H0_bar - H0)
        # H1 = E_{f*|x, Dt}[H(ym|f*, x, Dt)]
        Z = cdf_mvs_new / cdf_mvs  # s_M x s_m x batch_shape x num_fantasies
        h1 = -Z * Z.log() - Z * log_pdf_fm  # s_M x s_m x batch_shape x num_fantasies
        check_no_nans(h1)
        dim = [0, 1]  # dimension of fm samples, fM samples
        H1_bar = h1.mean(dim=dim)
        h0 = -log_pdf_fm
        H0_bar = h0.mean(dim=dim)
        cov = ((h1 - H1_bar) * (h0 - H0_bar)).mean(dim=dim)
        beta = cov / (h0.var(dim=dim) * h1.var(dim=dim)).sqrt()
        H1_hat = H1_bar - beta * (H0_bar - H0)
        ig = H0 - H1_hat  # batch_shape x num_fantasies
        ig = ig.permute(-1,
                        *range(ig.dim() - 1))  # num_fantasies x batch_shape
        return ig
Example #25
    def update_variational_parameters(self,
                                      new_x,
                                      new_y,
                                      new_inducing_points=None):
        # If new_inducing_points is None, this version of the variational update
        # does NOT assume that the inducing points are moving around as we add new
        # data. Because we are using gradient-based updates to optimize them, we do
        # not compute any type of randomized update to them as in Bui et al.
        if new_inducing_points is None:
            new_inducing_points = self.variational_strategy.inducing_points.detach(
            ).clone()

        self.set_streaming(True)

        with torch.no_grad():
            # self.register_streaming_loss(self.variational_strategy.variational_distribution)
            self.register_streaming_loss()

            if len(new_y.shape) == 1:
                new_y = new_y.view(-1, 1)

            S_a = self.variational_strategy.variational_distribution.lazy_covariance_matrix
            K_aa_old = self.variational_strategy.prior_distribution.lazy_covariance_matrix
            m_a = self.variational_strategy.variational_distribution.mean

            D_a_inv = (S_a.evaluate().inverse() -
                       K_aa_old.evaluate().inverse())
            # compute D S_a^{-1} m_a
            pseudo_points = torch.solve(
                S_a.inv_matmul(m_a).unsqueeze(-1), D_a_inv)[0]
            # stack y and the pseudo points
            hat_y = torch.cat((new_y.view(-1, 1), pseudo_points))

            # we now create Sigma_\hat y = blockdiag(\sigma^2 I; D_a)
            noise_diag = self.likelihood.noise * torch.eye(new_y.size(-2)).to(
                new_y.device)
            zero_part = torch.zeros(new_y.size(-2),
                                    pseudo_points.size(0)).to(new_y.device)
            tophalf = torch.cat((noise_diag, zero_part), -1)
            bottomhalf = torch.cat((zero_part.t(), D_a_inv.inverse()), -1)
            sigma_hat_y = torch.cat((tophalf, bottomhalf))

            # stack the data to be able to compute covariances with it
            # (x, a)
            stacked_data = torch.cat(
                (new_x, self.variational_strategy.inducing_points))

            K_fb = self.covar_module(stacked_data, new_inducing_points)
            K_bb = self.covar_module(new_inducing_points)

            # C = K_{hat f b} K_{bb}^{-1} K_{b hat f} + Sigma_\hat y
            pred_cov = K_fb @ (K_bb.inv_matmul(
                K_fb.evaluate().t())) + sigma_hat_y

            # the new mean is K_{hat f b} C^{-1} \hat y
            new_mean = K_fb.t() @ torch.solve(
                hat_y, pred_cov)[0].squeeze(-1).detach().contiguous()

            # the new covariance is K_bb - K_{hat f b} C^{-1} K_{b hat f}
            new_cov = K_bb - K_fb.t() @ torch.solve(K_fb.evaluate(),
                                                    pred_cov)[0]
            new_variational_chol = psd_safe_cholesky(
                new_cov.evaluate(),
                jitter=cholesky_jitter.value()).detach().contiguous()
            self.variational_strategy._variational_distribution.variational_mean.data.mul_(
                0.).add_(new_mean)
            self.variational_strategy._variational_distribution.chol_variational_covar.data.mul_(
                0.).add_(new_variational_chol)
            self.variational_strategy.inducing_points.data.mul_(0.).add_(
                new_inducing_points.detach())
Example #26
    def test_psd_safe_cholesky_nan(self, cuda=False):
        A = self._gen_test_psd().sqrt()
        with self.assertRaises(NanError) as ctx:
            psd_safe_cholesky(A)
        self.assertTrue("NaN" in str(ctx.exception))
Example #27
    def _compute_information_gain(self, X: Tensor, mean_M: Tensor,
                                  variance_M: Tensor,
                                  covar_mM: Tensor) -> Tensor:
        r"""Computes the information gain at the design points `X`.

        Approximately computes the information gain at the design points `X`,
        for both MES with noisy observations and multi-fidelity MES with noisy
        observation and trace observations.

        The implementation is inspired from the papers on multi-fidelity MES by
        [Takeno2020mfmves]_. The notation in the comments in this function follows
        the Appendix C of [Takeno2020mfmves]_.

        `num_fantasies = 1` for non-fantasized models.

        Args:
            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
                with `1` `d`-dim design point each.
            mean_M: A `batch_shape x num_fantasies x (m)`-dim Tensor of means.
            variance_M: A `batch_shape x num_fantasies x (m)`-dim Tensor of variances.
            covar_mM: A
                `batch_shape x num_fantasies x (m) x (1 + num_trace_observations)`-dim
                Tensor of covariances.

        Returns:
            A `num_fantasies x batch_shape`-dim Tensor of information gains at the
            given design points `X` (`num_fantasies=1` for non-fantasized models).
        """
        # compute the std_m, variance_m with noisy observation
        posterior_m = self.model.posterior(X.unsqueeze(-3),
                                           observation_noise=True)
        # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
        mean_m = self.weight * posterior_m.mean.squeeze(-1)
        # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
        variance_m = posterior_m.mvn.covariance_matrix
        check_no_nans(variance_m)

        # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
        samples_m = self.weight * self.sampler(posterior_m).squeeze(-1)
        # s_m x batch_shape x num_fantasies x (m) (1 + num_trace_observations)
        L = psd_safe_cholesky(variance_m)
        temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1),
                                         L).transpose(-2, -1)
        # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
        # batch_shape x num_fantasies (m) x 1 x (1 + num_trace_observations)

        mean_pt1 = torch.matmul(temp_term, (samples_m - mean_m).unsqueeze(-1))
        mean_new = mean_pt1.squeeze(-1).squeeze(-1) + mean_M
        # s_m x batch_shape x num_fantasies x (m)
        variance_pt1 = torch.matmul(temp_term, covar_mM.unsqueeze(-1))
        variance_new = variance_M - variance_pt1.squeeze(-1).squeeze(-1)
        # batch_shape x num_fantasies x (m)
        stdv_new = variance_new.clamp_min(CLAMP_LB).sqrt()
        # batch_shape x num_fantasies x (m)

        # define normal distribution to compute cdf and pdf
        normal = torch.distributions.Normal(
            torch.zeros(1, device=X.device, dtype=X.dtype),
            torch.ones(1, device=X.device, dtype=X.dtype),
        )

        # Compute p(fM <= f* | ym, x, Dt)
        view_shape = torch.Size([
            self.posterior_max_values.shape[0],
            # add 1s to broadcast across the batch_shape of X
            *[1 for _ in range(X.ndim - self.posterior_max_values.ndim)],
            *self.posterior_max_values.shape[1:],
        ])  # s_M x batch_shape x num_fantasies x (m)
        max_vals = self.posterior_max_values.view(view_shape).unsqueeze(1)
        # s_M x 1 x batch_shape x num_fantasies x (m)
        normalized_mvs_new = (max_vals - mean_new) / stdv_new
        # s_M x s_m x batch_shape x num_fantasies x (m)  =
        #   s_M x 1 x batch_shape x num_fantasies x (m)
        #   - s_m x batch_shape x num_fantasies x (m)
        cdf_mvs_new = normal.cdf(normalized_mvs_new).clamp_min(CLAMP_LB)

        # Compute p(fM <= f* | x, Dt)
        stdv_M = variance_M.sqrt()
        normalized_mvs = (max_vals - mean_M) / stdv_M
        # s_M x 1 x batch_shape x num_fantasies  x (m) =
        # s_M x 1 x 1 x num_fantasies x (m) - batch_shape x num_fantasies x (m)
        cdf_mvs = normal.cdf(normalized_mvs).clamp_min(CLAMP_LB)
        # s_M x 1 x batch_shape x num_fantasies x (m)

        # Compute log(p(ym | x, Dt))
        log_pdf_fm = posterior_m.mvn.log_prob(self.weight *
                                              samples_m).unsqueeze(0)
        # 1 x s_m x batch_shape x num_fantasies x (m)

        # H0 = H(ym | x, Dt)
        H0 = posterior_m.mvn.entropy()  # batch_shape x num_fantasies x (m)

        # regression adjusted H1 estimation, H1_hat = H1_bar - beta * (H0_bar - H0)
        # H1 = E_{f*|x, Dt}[H(ym|f*, x, Dt)]
        Z = cdf_mvs_new / cdf_mvs  # s_M x s_m x batch_shape x num_fantasies x (m)
        # s_M x s_m x batch_shape x num_fantasies x (m)
        h1 = -Z * Z.log() - Z * log_pdf_fm
        check_no_nans(h1)
        dim = [0, 1]  # dimension of fm samples, fM samples
        H1_bar = h1.mean(dim=dim)
        h0 = -log_pdf_fm
        H0_bar = h0.mean(dim=dim)
        cov = ((h1 - H1_bar) * (h0 - H0_bar)).mean(dim=dim)
        beta = cov / (h0.var(dim=dim) * h1.var(dim=dim)).sqrt()
        H1_hat = H1_bar - beta * (H0_bar - H0)
        ig = H0 - H1_hat  # batch_shape x num_fantasies x (m)
        if self.posterior_max_values.ndim == 2:
            permute_idcs = [-1, *range(ig.ndim - 1)]
        else:
            permute_idcs = [-2, *range(ig.ndim - 2), -1]
        ig = ig.permute(*permute_idcs)  # num_fantasies x batch_shape x (m)
        return ig
Example #28
def sample_cached_cholesky(
    posterior: GPyTorchPosterior,
    baseline_L: Tensor,
    q: int,
    base_samples: Tensor,
    sample_shape: torch.Size,
    max_tries: int = 6,
) -> Tensor:
    r"""Get posterior samples at the `q` new points from the joint multi-output posterior.

    Args:
        posterior: The joint posterior is over (X_baseline, X).
        baseline_L: The baseline lower triangular cholesky factor.
        q: The number of new points in X.
        base_samples: The base samples.
        sample_shape: The sample shape.
        max_tries: The number of tries for computing the Cholesky
            decomposition with increasing jitter.


    Returns:
        A `sample_shape x batch_shape x q x m`-dim tensor of posterior
            samples at the new points.
    """
    # compute bottom left covariance block
    if isinstance(posterior.mvn, MultitaskMultivariateNormal):
        lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn)
    else:
        lazy_covar = posterior.mvn.lazy_covariance_matrix
    # Get the `q` new rows of the batched covariance matrix
    bottom_rows = lazy_covar[..., -q:, :].evaluate()
    # The covariance in block form is:
    # [K(X_baseline, X_baseline), K(X_baseline, X)]
    # [K(X, X_baseline), K(X, X)]
    # bl := K(X, X_baseline)
    # br := K(X, X)
    # Get bottom right block of new covariance
    bl, br = bottom_rows.split([bottom_rows.shape[-1] - q, q], dim=-1)
    # Solve Ax = b
    # where A = K(X_baseline, X_baseline) and b = K(X, X_baseline)^T
    # and bl_chol := x^T
    # bl_chol is the new `(batch_shape) x q x n`-dim bottom left block
    # of the cholesky decomposition
    bl_chol = torch.triangular_solve(bl.transpose(-2, -1),
                                     baseline_L,
                                     upper=False).solution.transpose(-2, -1)
    # Compute the new bottom right block of the Cholesky
    # decomposition via:
    # Cholesky(K(X, X) - bl_chol @ bl_chol^T)
    br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1)
    # TODO: technically we should make sure that we add a
    # consistent nugget to the cached covariance and the new block
    br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries)
    # Create a `(batch_shape) x q x (n+q)`-dim tensor containing the
    # `q` new bottom rows of the Cholesky decomposition
    new_Lq = torch.cat([bl_chol, br_chol], dim=-1)
    mean = posterior.mvn.mean
    base_samples = _reshape_base_samples(
        base_samples=base_samples,
        sample_shape=sample_shape,
        posterior=posterior,
    )
    if not isinstance(posterior.mvn, MultitaskMultivariateNormal):
        # add output dim
        mean = mean.unsqueeze(-1)
        # add batch dim corresponding to output dim
        new_Lq = new_Lq.unsqueeze(-3)
    new_mean = mean[..., -q:, :]
    res = (new_Lq.matmul(base_samples).add(
        new_mean.transpose(-1, -2).unsqueeze(-1)).permute(
            -1, *range(posterior.mvn.loc.dim() - 1), -2, -3).contiguous())
    contains_nans = torch.isnan(res).any()
    contains_infs = torch.isinf(res).any()
    if contains_nans or contains_infs:
        suffix_args = []
        if contains_nans:
            suffix_args.append("nans")
        if contains_infs:
            suffix_args.append("infs")
        suffix = " and ".join(suffix_args)
        raise NanError(f"Samples contain {suffix}.")
    return res
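
The block update in sample_cached_cholesky follows the standard partitioned-Cholesky identity: with bl = K(X, X_baseline) and br = K(X, X), the q new rows of the joint factor are [bl_chol, chol(br - bl_chol bl_chol^T)] where bl_chol = bl L_baseline^{-T}. A standalone float64 check with synthetic data (the names n, q, baseline_L only mirror the function above):

import torch

torch.manual_seed(0)
n, q = 5, 2
B = torch.randn(n + q, n + q, dtype=torch.float64)
K = B @ B.T + (n + q) * torch.eye(n + q, dtype=torch.float64)

baseline_L = torch.linalg.cholesky(K[:n, :n])
bl = K[n:, :n]                                   # K(X, X_baseline)
br = K[n:, n:]                                   # K(X, X)
bl_chol = torch.linalg.solve_triangular(
    baseline_L, bl.transpose(-2, -1), upper=False
).transpose(-2, -1)                              # bl @ L_baseline^{-T}
br_chol = torch.linalg.cholesky(br - bl_chol @ bl_chol.transpose(-2, -1))
new_Lq = torch.cat([bl_chol, br_chol], dim=-1)   # q x (n + q)
print(torch.allclose(new_Lq, torch.linalg.cholesky(K)[n:, :]))  # True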