Example 1
    def test_inv_quad_log_det_many_vectors(self):
        """inv_quad_log_det should match the explicit inverse/det computation,
        in both the forward values and the gradients."""
        # Reference values computed the slow, explicit way on the clone.
        solve = self.mat_var_clone.inverse().matmul(self.vecs_var_clone)
        expected_inv_quad = (solve * self.vecs_var_clone).sum()
        expected_log_det = self.mat_var_clone.det().log()

        with gpytorch.settings.num_trace_samples(1000):
            lazy_mat = NonLazyTensor(self.mat_var)
            got_inv_quad, got_log_det = lazy_mat.inv_quad_log_det(
                inv_quad_rhs=self.vecs_var, log_det=True)

        self.assertAlmostEqual(got_inv_quad.item(), expected_inv_quad.item(),
                               places=1)
        self.assertAlmostEqual(got_log_det.item(), expected_log_det.item(),
                               places=1)

        # Backward: gradients from both computation paths should agree.
        expected_inv_quad.backward()
        expected_log_det.backward()
        got_inv_quad.backward(retain_graph=True)
        got_log_det.backward()

        self.assertTrue(
            approx_equal(self.mat_var_clone.grad, self.mat_var.grad,
                         epsilon=1e-1))
        self.assertTrue(
            approx_equal(self.vecs_var_clone.grad, self.vecs_var.grad))
Example 2
    def test_inv_quad_log_det_many_vectors(self):
        """Batched inv_quad_log_det should match per-matrix inverse/logdet
        references, in both forward values and gradients."""
        # Explicit per-matrix references, stacked back into a batch.
        batch_inverse = torch.stack(
            [self.mats_var_clone[i].inverse() for i in range(2)])
        solves = batch_inverse.matmul(self.vecs_var_clone)
        expected_inv_quad = (solves * self.vecs_var_clone).sum(2).sum(1)
        expected_log_det = torch.stack(
            [self.mats_var_clone[i].det().log() for i in range(2)])

        with gpytorch.settings.num_trace_samples(1000):
            lazy_mat = NonLazyTensor(self.mats_var)
            got_inv_quad, got_log_det = lazy_mat.inv_quad_log_det(
                inv_quad_rhs=self.vecs_var, log_det=True)

        self.assertTrue(
            approx_equal(got_inv_quad, expected_inv_quad, epsilon=1e-1))
        self.assertTrue(
            approx_equal(got_log_det, expected_log_det, epsilon=1e-1))

        # Backward with non-trivial upstream gradients per batch element.
        grad_inv_quad = torch.tensor([3, 4], dtype=torch.float)
        grad_log_det = torch.tensor([4, 2], dtype=torch.float)
        expected_inv_quad.backward(gradient=grad_inv_quad)
        expected_log_det.backward(gradient=grad_log_det)
        got_inv_quad.backward(gradient=grad_inv_quad, retain_graph=True)
        got_log_det.backward(gradient=grad_log_det)

        self.assertTrue(
            approx_equal(self.mats_var_clone.grad, self.mats_var.grad,
                         epsilon=1e-1))
        self.assertTrue(
            approx_equal(self.vecs_var_clone.grad, self.vecs_var.grad))
Example 3
    def test_inv_quad_log_det_many_vectors(self):
        """inv_quad_log_det should match the explicit inverse/logdet
        computation, in both forward values and gradients."""
        # Reference values computed the slow, explicit way on the clone.
        solve = self.mat_clone.inverse().matmul(self.vecs_clone)
        expected_inv_quad = (solve * self.vecs_clone).sum()
        expected_log_det = self.mat_clone.logdet()

        with gpytorch.settings.num_trace_samples(1000):
            lazy_mat = NonLazyTensor(self.mat)
            got_inv_quad, got_log_det = lazy_mat.inv_quad_log_det(
                inv_quad_rhs=self.vecs, log_det=True)

        self.assertAlmostEqual(got_inv_quad.item(), expected_inv_quad.item(),
                               places=1)
        self.assertAlmostEqual(got_log_det.item(), expected_log_det.item(),
                               places=1)

        # Backward: gradients from both computation paths should agree.
        expected_inv_quad.backward()
        expected_log_det.backward()
        got_inv_quad.backward(retain_graph=True)
        got_log_det.backward()

        self.assertLess(
            (self.mat_clone.grad - self.mat.grad).abs().max().item(), 1e-1)
        self.assertLess(
            (self.vecs_clone.grad - self.vecs.grad).abs().max().item(), 1e-1)
Example 4
    def test_inv_quad_log_det_many_vectors_improper(self):
        """With skip_logdet_forward enabled the returned forward log-det is
        (near) zero, but the inv-quad value and all gradients should still
        match the explicit computation."""
        # Explicit per-matrix references, stacked back into a batch.
        solves = torch.stack(
            [m.inverse() for m in self.mats_clone]).matmul(self.vecs_clone)
        expected_inv_quad = (solves * self.vecs_clone).sum(2).sum(1)
        expected_log_det = torch.stack([m.logdet() for m in self.mats_clone])

        with gpytorch.settings.num_trace_samples(
                2000), gpytorch.settings.skip_logdet_forward(True):
            lazy_mat = NonLazyTensor(self.mats)
            got_inv_quad, got_log_det = lazy_mat.inv_quad_log_det(
                inv_quad_rhs=self.vecs, log_det=True)

        self.assertEqual(got_inv_quad.shape, expected_inv_quad.shape)
        self.assertEqual(got_log_det.shape, expected_log_det.shape)
        self.assertLess(
            (got_inv_quad - expected_inv_quad).abs().max().item(), 1e-1)
        # Forward log-det was skipped, so only check it is close to zero.
        self.assertLess(got_log_det.abs().max().item(), 1e-1)

        # Backward with random upstream gradients per batch element.
        grad_inv_quad = torch.randn(5, dtype=torch.float)
        grad_log_det = torch.randn(5, dtype=torch.float)
        expected_inv_quad.backward(gradient=grad_inv_quad)
        expected_log_det.backward(gradient=grad_log_det)
        got_inv_quad.backward(gradient=grad_inv_quad, retain_graph=True)
        got_log_det.backward(gradient=grad_log_det)

        self.assertLess(
            (self.mats_clone.grad - self.mats.grad).abs().max().item(), 1e-1)
        self.assertLess(
            (self.vecs_clone.grad - self.vecs.grad).abs().max().item(), 1e-1)
Example 5
    def test_inv_quad_log_det_many_vectors(self):
        """Multi-batch inv_quad_log_det should match per-matrix
        inverse/logdet references, in both forward values and gradients."""
        # Collapse the leading batch dims so we can invert matrix-by-matrix,
        # then restore the original batch shape.
        batch_shape = self.mats_clone.shape[:-2]
        flat_mats = self.mats_clone.reshape(-1, *self.mats_clone.shape[-2:])
        solves = torch.stack(
            [m.inverse() for m in flat_mats]).reshape(
                self.mats_clone.shape).matmul(self.vecs_clone)
        expected_inv_quad = (solves * self.vecs_clone).sum(-2).sum(-1)
        expected_log_det = torch.stack(
            [m.logdet() for m in flat_mats]).reshape(batch_shape)

        with gpytorch.settings.num_trace_samples(2000):
            lazy_mat = NonLazyTensor(self.mats)
            got_inv_quad, got_log_det = lazy_mat.inv_quad_log_det(
                inv_quad_rhs=self.vecs, log_det=True)

        self.assertEqual(got_inv_quad.shape, expected_inv_quad.shape)
        self.assertEqual(got_log_det.shape, expected_log_det.shape)
        self.assertLess(
            (got_inv_quad - expected_inv_quad).abs().max().item(), 1e-1)
        self.assertLess(
            (got_log_det - expected_log_det).abs().max().item(), 1e-1)

        # Backward with random upstream gradients per batch element.
        grad_inv_quad = torch.randn(2, 3, dtype=torch.float)
        grad_log_det = torch.randn(2, 3, dtype=torch.float)
        expected_inv_quad.backward(gradient=grad_inv_quad)
        expected_log_det.backward(gradient=grad_log_det)
        got_inv_quad.backward(gradient=grad_inv_quad, retain_graph=True)
        got_log_det.backward(gradient=grad_log_det)

        self.assertLess(
            (self.mats_clone.grad - self.mats.grad).abs().max().item(), 1e-1)
        self.assertLess(
            (self.vecs_clone.grad - self.vecs.grad).abs().max().item(), 1e-1)