Exemple #1
0
    def block_logdet(self, var, cov_mat_root):
        """Return the log-determinant of a diagonal-plus-root covariance block.

        The covariance is assembled lazily as the sum of a diagonal term built
        from ``var`` and a root-factored term built from ``cov_mat_root``.

        Args:
            var: Diagonal entries; passed through ``flatten`` before use.
            cov_mat_root: Root factor; its transpose is handed to
                ``RootLazyTensor``.

        Returns:
            The result of ``log_det()`` on the combined lazy tensor.
        """
        flat_var = flatten(var)
        root_term = RootLazyTensor(cov_mat_root.t())
        # A small constant (1e-6) is added to every diagonal entry.
        diag_term = DiagLazyTensor(flat_var + 1e-6)
        return AddedDiagLazyTensor(diag_term, root_term).log_det()
 def create_lazy_tensor(self):
     """Build a batched ``AddedDiagLazyTensor`` fixture.

     The dense part is a random ``3 x 5 x 5`` batch multiplied by its own
     transpose and then detached from the autograd graph. The diagonal part
     is a fixed ``3 x 5`` tensor created with ``requires_grad=True``.

     Returns:
         ``AddedDiagLazyTensor`` wrapping the dense and diagonal parts.
     """
     base = torch.randn(3, 5, 5)
     gram = torch.matmul(base.transpose(-1, -2), base).detach()
     diag_rows = [
         [1.0, 2.0, 4.0, 2.0, 3.0],
         [2.0, 1.0, 2.0, 1.0, 4.0],
         [1.0, 2.0, 2.0, 3.0, 4.0],
     ]
     diag = torch.tensor(diag_rows, requires_grad=True)
     return AddedDiagLazyTensor(NonLazyTensor(gram), DiagLazyTensor(diag))
Exemple #3
0
    def compute_ll_for_block(self, vec, mean, var, cov_mat_root):
        """Evaluate the Gaussian log-density of ``vec`` for one block.

        The covariance is assembled lazily as the sum of a diagonal term built
        from ``var`` (plus 1e-6 on every entry) and a root-factored term built
        from the transpose of ``cov_mat_root``.

        Args:
            vec: Point at which the density is evaluated; passed through
                ``flatten``.
            mean: Mean of the distribution; passed through ``flatten``.
            var: Diagonal entries of the covariance; passed through
                ``flatten``.
            cov_mat_root: Root factor of the non-diagonal covariance term.

        Returns:
            The result of ``MultivariateNormal.log_prob(vec)``, computed with
            ``num_trace_samples`` set to 1 and ``max_cg_iterations`` set to 25.
        """
        vec = flatten(vec)
        mean = flatten(mean)
        var = flatten(var)

        cov_mat_lt = RootLazyTensor(cov_mat_root.t())
        var_lt = DiagLazyTensor(var + 1e-6)
        covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
        qdist = MultivariateNormal(mean, covar_lt)

        # Both settings must be entered as separate context managers. Joining
        # them with ``and`` evaluates to the second object only, so the
        # ``num_trace_samples`` override would never take effect.
        with gpytorch.settings.num_trace_samples(1), gpytorch.settings.max_cg_iterations(25):
            return qdist.log_prob(vec)
Exemple #4
0
def _get_test_posterior(batch_shape: torch.Size,
                        q: int = 1,
                        m: int = 1,
                        interleaved: bool = True,
                        lazy: bool = False,
                        independent: bool = False,
                        **tkwargs) -> GPyTorchPosterior:
    r"""Build a randomly generated ``GPyTorchPosterior`` for tests.

    Args:
        batch_shape: Leading batch dimensions of every generated tensor.
        q: The number of candidates.
        m: The number of outputs.
        interleaved: Forwarded as ``interleaved`` to
            ``MultitaskMultivariateNormal``; only used when ``independent``
            is False.
        lazy: If True, the joint covariance is wrapped in an
            ``AddedDiagLazyTensor`` instead of being summed into a dense
            tensor; only used when ``independent`` is False.
        independent: If True, one ``MultivariateNormal`` is generated per
            output and the results are combined with
            ``MultitaskMultivariateNormal.from_independent_mvns``.
        tkwargs: `device` and `dtype` tensor constructor kwargs, forwarded to
            every ``torch.rand`` call.

    Returns:
        A ``GPyTorchPosterior`` wrapping the generated multitask distribution.
    """

    def _random_gram(dim):
        # Random square factor multiplied by its own transpose.
        factor = torch.rand(*batch_shape, dim, dim, **tkwargs)
        return factor @ factor.transpose(-1, -2)

    if not independent:
        joint_dim = q * m
        joint_mean = torch.rand(*batch_shape, q, m, **tkwargs)
        gram = _random_gram(joint_dim)
        diag_entries = torch.rand(*batch_shape, joint_dim, **tkwargs)
        if lazy:
            joint_covar = AddedDiagLazyTensor(gram, DiagLazyTensor(diag_entries))
        else:
            joint_covar = gram + torch.diag_embed(diag_entries)
        joint_mvn = MultitaskMultivariateNormal(
            joint_mean, joint_covar, interleaved=interleaved)
        return GPyTorchPosterior(joint_mvn)

    # Independent outputs: draw mean, factor and diagonal for each output in
    # turn, in that order.
    per_output = []
    for _ in range(m):
        out_mean = torch.rand(*batch_shape, q, **tkwargs)
        out_gram = _random_gram(q)
        out_diag = torch.rand(*batch_shape, q, **tkwargs)
        per_output.append(
            MultivariateNormal(out_mean, out_gram + torch.diag_embed(out_diag)))
    return GPyTorchPosterior(
        MultitaskMultivariateNormal.from_independent_mvns(per_output))
Exemple #5
0
    def test_precond_solve(self):
        """Check that an overridden preconditioner still yields a usable solve.

        Two ``AddedDiagLazyTensor`` objects are built from the same root and
        diagonal: one with default settings, and one whose
        ``preconditioner_override`` is built from the first 100 eigenpairs
        returned by ``symeig``, with noise added to the eigenvalues. Both are
        used to solve against the same right-hand side; the test asserts that
        the solutions have equal shape and that their relative difference is
        below 1.0.
        """
        torch.random.manual_seed(4)

        root = torch.randn(1000, 800)
        diag = torch.abs(torch.randn(1000))

        standard_lt = AddedDiagLazyTensor(
            RootLazyTensor(root), DiagLazyTensor(diag))
        evals, evecs = standard_lt.symeig(eigenvectors=True)

        def nonstandard_preconditioner(self):
            # this preconditioner is a simple example of near deflation
            evecs_100 = evecs[:, :100]
            evals_100 = evals[:100] + 0.2 * torch.randn(100)

            sqrt_scale = torch.diag(evals_100**0.5)
            precond_lt = RootLazyTensor(evecs_100 @ sqrt_scale)
            logdet = evals_100.log().sum()

            def precond_closure(rhs):
                projected = evecs_100.t() @ rhs
                inv_scale = torch.diag(1.0 / evals_100)
                return evecs_100 @ inv_scale @ projected

            return precond_closure, precond_lt, logdet

        overrode_lt = AddedDiagLazyTensor(
            RootLazyTensor(root),
            DiagLazyTensor(diag),
            preconditioner_override=nonstandard_preconditioner,
        )

        # Run both solves, mostly to confirm that the overridden one can
        # actually be performed.
        rhs = torch.randn(1000, 1)
        standard_solve = standard_lt.inv_matmul(rhs)
        overrode_solve = overrode_lt.inv_matmul(rhs)

        # Sanity check that the custom preconditioner does not break the solve.
        self.assertEqual(standard_solve.shape, overrode_solve.shape)
        rel_gap = (torch.norm(standard_solve - overrode_solve) /
                   standard_solve.norm())
        self.assertLess(rel_gap, 1.0)
 def create_lazy_tensor(self):
     """Build a batched ``AddedDiagLazyTensor`` fixture that tracks gradients.

     The dense part is a random ``3 x 5 x 5`` batch multiplied by its own
     transpose, with ``requires_grad`` switched on afterwards. The diagonal
     part is a fixed ``3 x 5`` tensor created with ``requires_grad=True``.

     Returns:
         ``AddedDiagLazyTensor`` wrapping the dense and diagonal parts.
     """
     base = torch.randn(3, 5, 5)
     # ``requires_grad_`` works in place and returns the same tensor.
     gram = base.transpose(-1, -2).matmul(base).requires_grad_(True)
     diag = torch.tensor(
         [[1., 2., 4., 2., 3.], [2., 1., 2., 1., 4.], [1., 2., 2., 3., 4.]],
         requires_grad=True,
     )
     return AddedDiagLazyTensor(NonLazyTensor(gram), DiagLazyTensor(diag))
    def test_added_diag_lt(self, N: int = 10000, p: int = 20, use_cuda: bool = False, seed: int = 1) -> None:
        """Compare lazy and dense log-probabilities for a root-plus-diagonal covariance.

        Builds ``DD' + diag(A)`` as an ``AddedDiagLazyTensor``, prints several
        diagnostic comparisons between different ways of computing the solve
        and the log-determinant, and finally asserts that the log-probability
        from the lazy ``MultivariateNormal`` is within 1% relative error of the
        one from ``torch.distributions.MultivariateNormal`` on the evaluated
        covariance.

        Args:
            N: Number of rows of ``D`` and length of ``A``, ``z`` and ``mean``.
            p: Number of columns of ``D``.
            use_cuda: Run on CUDA when it is also available; otherwise CPU.
            seed: Seed passed to ``torch.manual_seed`` (and to
                ``torch.cuda.manual_seed_all`` on the CUDA path).
        """

        torch.manual_seed(seed)

        if torch.cuda.is_available() and use_cuda:
            print("Using cuda")
            device = torch.device("cuda")
            torch.cuda.manual_seed_all(seed)
        else:
            device = torch.device("cpu")

        D = torch.randn(N, p, device=device)
        # Diagonal entries are 0.1 plus a small non-negative perturbation.
        A = torch.randn(N, device=device).abs() * 1e-3 + 0.1

        # this is a lazy tensor for DD'
        D_lt = RootLazyTensor(D)

        # this is a lazy tensor for diag(A)
        diag_term = DiagLazyTensor(A)

        # DD' + diag(A)
        lowrank_pdiag_lt = AddedDiagLazyTensor(diag_term, D_lt)

        # z \sim N(0,I), mean = 1
        z = torch.randn(N, device=device)
        mean = torch.ones(N, device=device)

        diff = mean - z

        # NOTE(review): log_det() is called twice here, once for the print and
        # once for the stored value.
        print(lowrank_pdiag_lt.log_det())
        logdet = lowrank_pdiag_lt.log_det()
        inv_matmul = lowrank_pdiag_lt.inv_matmul(diff.unsqueeze(1)).squeeze(1)
        # Quadratic form diff' * (covariance)^{-1} * diff.
        inv_matmul_quad = torch.dot(diff, inv_matmul)
        # The two bare string literals below hold disabled code; as expression
        # statements they have no effect at runtime.
        """inv_matmul_quad_qld, logdet_qld = lowrank_pdiag_lt.inv_quad_log_det(inv_quad_rhs=diff.unsqueeze(1), log_det = True)
        
        """
        """from gpytorch.functions._inv_quad_log_det import InvQuadLogDet
        iqld_construct = InvQuadLogDet(gpytorch.lazy.lazy_tensor_representation_tree.LazyTensorRepresentationTree(lowrank_pdiag_lt),
                            matrix_shape=lowrank_pdiag_lt.matrix_shape,
                            dtype=lowrank_pdiag_lt.dtype,
                            device=lowrank_pdiag_lt.device,
                            inv_quad=True,
                            log_det=True,
                            preconditioner=lowrank_pdiag_lt._preconditioner()[0],
                            log_det_correction=lowrank_pdiag_lt._preconditioner()[1])
        inv_matmul_quad_qld, logdet_qld = iqld_construct(diff.unsqueeze(1))"""
        # Manual reconstruction of the solve and log-determinant estimate:
        # build random probe vectors, one column per trace sample.
        num_random_probes = gpytorch.settings.num_trace_samples.value()
        probe_vectors = torch.empty(
            lowrank_pdiag_lt.matrix_shape[-1],
            num_random_probes,
            dtype=lowrank_pdiag_lt.dtype,
            device=lowrank_pdiag_lt.device,
        )
        # Fill with 0/1 draws and map them to -1/+1, then scale each column to
        # unit 2-norm.
        probe_vectors.bernoulli_().mul_(2).add_(-1)
        probe_vector_norms = torch.norm(probe_vectors, 2, dim=-2, keepdim=True)
        probe_vectors = probe_vectors.div(probe_vector_norms)

        # diff_norm = diff.norm()
        # diff = diff/diff_norm
        # Column 0 is the actual right-hand side; the rest are the probes.
        rhs = torch.cat([diff.unsqueeze(1), probe_vectors], dim=1)

        # One CG run over all columns; with n_tridiag set it also returns the
        # tridiagonal matrices used further below.
        solves, t_mat = gpytorch.utils.linear_cg(
            lowrank_pdiag_lt.matmul,
            rhs,
            n_tridiag=num_random_probes,
            max_iter=gpytorch.settings.max_cg_iterations.value(),
            max_tridiag_iter=gpytorch.settings.
            max_lanczos_quadrature_iterations.value(),
            preconditioner=lowrank_pdiag_lt._preconditioner()[0],
        )
        # print(solves)
        # Solve corresponding to the `diff` column of rhs.
        inv_matmul_qld = solves[:, 0]  # * diff_norm

        # Separate CG run on `diff` alone, to compare with the joint run.
        diff_solve = gpytorch.utils.linear_cg(
            lowrank_pdiag_lt.matmul,
            diff.unsqueeze(1),
            max_iter=gpytorch.settings.max_cg_iterations.value(),
            preconditioner=lowrank_pdiag_lt._preconditioner()[0],
        )
        print("diff_solve_norm: ", diff_solve.norm())
        print(
            "diff between multiple linear_cg: ",
            (inv_matmul_qld.unsqueeze(1) - diff_solve).norm() /
            diff_solve.norm(),
        )

        # Turn the tridiagonal matrices into eigenvalues/eigenvectors and feed
        # them, with a log function, to the stochastic quadrature estimator.
        eigenvalues, eigenvectors = gpytorch.utils.lanczos.lanczos_tridiag_to_diag(
            t_mat)
        slq = gpytorch.utils.StochasticLQ()
        log_det_term, = slq.evaluate(
            lowrank_pdiag_lt.matrix_shape,
            eigenvalues,
            eigenvectors,
            [lambda x: x.log()],
        )
        # Assumes `_preconditioner()[1]` is the log-det correction term for the
        # gpytorch version in use -- TODO confirm.
        logdet_qld = log_det_term + lowrank_pdiag_lt._preconditioner()[1]

        print("Log det difference: ",
              (logdet - logdet_qld).norm() / logdet.norm())
        # NOTE(review): this is normalized by the scalar quadratic form
        # (`inv_matmul_quad`), not by `inv_matmul.norm()` -- confirm intent.
        print(
            "inv matmul difference: ",
            (inv_matmul - inv_matmul_qld).norm() / inv_matmul_quad.norm(),
        )

        # N(1, DD' + diag(A))
        lazydist = MultivariateNormal(mean, lowrank_pdiag_lt)
        lazy_lprob = lazydist.log_prob(z)

        # exact log probability with Cholesky decomposition
        # (the lazy covariance is evaluated and cast to float first)
        exact_dist = torch.distributions.MultivariateNormal(
            mean,
            lowrank_pdiag_lt.evaluate().float())
        exact_lprob = exact_dist.log_prob(z)

        print(lazy_lprob, exact_lprob)
        rel_error = torch.norm(lazy_lprob - exact_lprob) / exact_lprob.norm()

        # The only assertion: lazy and dense log-probabilities agree to 1%.
        self.assertLess(rel_error.cpu().item(), 0.01)