def _approx_diag(self):
    """Cheap diagonal approximation: push the square root of the base
    diagonal through both interpolation operators and multiply elementwise.
    """
    sqrt_diag = self.base_lazy_tensor.diag().sqrt().unsqueeze(-1)
    interp_left = left_interp(self.left_interp_indices,
                              self.left_interp_values,
                              sqrt_diag)
    interp_right = left_interp(self.right_interp_indices,
                               self.right_interp_values,
                               sqrt_diag)
    return (interp_left * interp_right).squeeze(-1)
 def diag(self):
     """Exact diagonal when the base tensor is a root decomposition of a
     plain (non-lazy) tensor; otherwise defer to the generic implementation.
     """
     base = self.base_lazy_tensor
     exact_path = (isinstance(base, RootLazyTensor)
                   and isinstance(base.root, NonLazyTensor))
     if not exact_path:
         return super(InterpolatedLazyTensor, self).diag()

     root_eval = base.root.evaluate()
     left_vals = left_interp(self.left_interp_indices,
                             self.left_interp_values, root_eval)
     right_vals = left_interp(self.right_interp_indices,
                              self.right_interp_values, root_eval)
     return (left_vals * right_vals).sum(-1)
    def matmul(self, tensor):
        """Compute (left_interp @ base @ right_interp^T) @ tensor.

        Hand-rolled rather than produced by the function factory because it
        is significantly faster; _matmul_closure remains the optimized path
        for repeated calls such as inv_matmul.
        """
        squeeze_result = tensor.ndimension() == 1
        if squeeze_result:
            # Promote a vector to a one-column matrix for the interp helpers.
            tensor = tensor.unsqueeze(-1)

        # right_interp^T @ tensor
        intermediate = left_t_interp(self.right_interp_indices,
                                     self.right_interp_values, tensor,
                                     self.base_lazy_tensor.size(-1))
        # base_lazy_tensor @ (right_interp^T @ tensor)
        intermediate = self.base_lazy_tensor.matmul(intermediate)
        # left_interp @ base_lazy_tensor @ right_interp^T @ tensor
        product = left_interp(self.left_interp_indices,
                              self.left_interp_values, intermediate)

        return product.squeeze(-1) if squeeze_result else product
# --- Ejemplo n.º 4 (scraped example separator) ---
    def forward(self,
                x,
                inducing_points,
                inducing_values,
                variational_inducing_covar=None):
        """Produce the predictive MultivariateNormal at inputs ``x`` by
        interpolating the grid variational distribution.

        Args:
            x: Test inputs.
            inducing_points: Grid inducing-point locations (not read here;
                interpolation coefficients come from ``_compute_grid(x)``).
            inducing_values: Variational mean at the inducing points.
            variational_inducing_covar: Covariance of the variational
                distribution; must be provided (Gaussian variational family).

        Returns:
            MultivariateNormal with interpolated mean and lazily
            interpolated covariance.

        Raises:
            RuntimeError: If ``variational_inducing_covar`` is None, i.e.
                the variational distribution is not Gaussian.
        """
        if variational_inducing_covar is None:
            # Fixed malformed message: the original opened a "(" that was
            # never closed.
            raise RuntimeError(
                "GridInterpolationVariationalStrategy is only compatible with Gaussian variational "
                f"distributions. Got {self.variational_distribution.__class__.__name__}."
            )

        variational_distribution = self.variational_distribution

        # Sparse interpolation coefficients of x against the grid
        interp_indices, interp_values = self._compute_grid(x)

        # Test mean: left-multiply the inducing values by the
        # interpolation matrix
        predictive_mean = left_interp(interp_indices, interp_values,
                                      inducing_values.unsqueeze(-1))
        predictive_mean = predictive_mean.squeeze(-1)

        # Test covariance: W @ S @ W^T, kept lazy
        predictive_covar = InterpolatedLazyTensor(
            variational_distribution.lazy_covariance_matrix,
            interp_indices,
            interp_values,
            interp_indices,
            interp_values,
        )
        return MultivariateNormal(predictive_mean, predictive_covar)
# --- Ejemplo n.º 5 (scraped example separator) ---
    def exact_predictive_covar(self, test_test_covar, test_train_covar):
        """Predictive covariance using the fast-prediction caches.

        Falls back to the generic implementation when both fast-prediction
        settings are disabled.
        """
        if settings.fast_pred_var.off() and settings.fast_pred_samples.off():
            return super(InterpolatedPredictionStrategy,
                         self).exact_predictive_covar(test_test_covar,
                                                      test_train_covar)

        self._last_test_train_covar = test_train_covar
        test_interp_indices = test_train_covar.left_interp_indices
        test_interp_values = test_train_covar.left_interp_values

        # Slot 0 of the cache serves fast_pred_samples, slot 1 fast_pred_var.
        use_samples_cache = settings.fast_pred_samples.on()
        cache_slot = 0 if use_samples_cache else 1
        precomputed_cache = self.covar_cache
        if precomputed_cache[cache_slot] is None:
            # Cache entry for the active mode is missing: drop and rebuild.
            pop_from_cache(self, "covar_cache")
            precomputed_cache = self.covar_cache

        # Compute the exact predictive posterior
        if use_samples_cache:
            root = self._exact_predictive_covar_inv_quad_form_root(
                precomputed_cache[0], test_train_covar)
            return RootLazyTensor(root)
        root = left_interp(test_interp_indices, test_interp_values,
                           precomputed_cache[1])
        return test_test_covar + RootLazyTensor(root).mul(-1)
# --- Ejemplo n.º 6 (scraped example separator) ---
 def exact_predictive_mean(self, test_mean, test_train_covar):
     """Predictive mean: interpolate the cached solve against the test
     interpolation terms and add the prior test mean."""
     interp_indices = test_train_covar.left_interp_indices
     interp_values = test_train_covar.left_interp_values
     correction = left_interp(interp_indices, interp_values, self.mean_cache)
     return correction.squeeze(-1) + test_mean
 def zero_mean_mvn_samples(self, num_samples):
     """Draw zero-mean MVN samples from the base tensor and interpolate them.

     The sample dimension is moved to the end so left_interp can act on the
     trailing dimension, then restored to the front afterwards.
     """
     samples = self.base_lazy_tensor.zero_mean_mvn_samples(num_samples)
     # (s, ...) -> (..., s): push the sample dim last
     samples = samples.permute(*range(1, samples.dim()), 0)
     interp_samples = left_interp(self.left_interp_indices,
                                  self.left_interp_values,
                                  samples).contiguous()
     # (..., s) -> (s, ...): bring the sample dim back to the front
     leading_dims = tuple(range(interp_samples.dim() - 1))
     return interp_samples.permute(-1, *leading_dims).contiguous()
# --- Ejemplo n.º 8 (scraped example separator) ---
 def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache,
                                                test_train_covar):
     """Interpolate the precomputed root against the test interpolation
     terms. The cache represents K_UU W S, where S S^T = (K_XX + sigma^2 I)^-1.
     """
     return left_interp(test_train_covar.left_interp_indices,
                        test_train_covar.left_interp_values,
                        precomputed_cache)