def test_inv_quad_log_det_many_vectors(self):
    # Forward pass
    actual_inv_quad = (
        torch.cat([
            self.mats_var_clone[0].inverse().unsqueeze(0),
            self.mats_var_clone[1].inverse().unsqueeze(0),
        ])
        .matmul(self.vecs_var_clone)
        .mul(self.vecs_var_clone)
        .sum(2)
        .sum(1)
    )
    actual_log_det = torch.cat([
        self.mats_var_clone[0].det().log().unsqueeze(0),
        self.mats_var_clone[1].det().log().unsqueeze(0),
    ])
    with gpytorch.settings.num_trace_samples(1000):
        nlv = NonLazyTensor(self.mats_var)
        res_inv_quad, res_log_det = nlv.inv_quad_log_det(inv_quad_rhs=self.vecs_var, log_det=True)
    self.assertTrue(approx_equal(res_inv_quad, actual_inv_quad, epsilon=1e-1))
    self.assertTrue(approx_equal(res_log_det, actual_log_det, epsilon=1e-1))

    # Backward
    inv_quad_grad_output = torch.tensor([3, 4], dtype=torch.float)
    log_det_grad_output = torch.tensor([4, 2], dtype=torch.float)
    actual_inv_quad.backward(gradient=inv_quad_grad_output)
    actual_log_det.backward(gradient=log_det_grad_output)
    res_inv_quad.backward(gradient=inv_quad_grad_output, retain_graph=True)
    res_log_det.backward(gradient=log_det_grad_output)
    self.assertTrue(approx_equal(self.mats_var_clone.grad, self.mats_var.grad, epsilon=1e-1))
    self.assertTrue(approx_equal(self.vecs_var_clone.grad, self.vecs_var.grad))

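# `approx_equal` is imported from the gpytorch test utilities and is not shown in
# this snippet. A minimal sketch of the behavior these assertions assume (largest
# absolute elementwise difference within `epsilon`); an assumption, not the
# library's exact implementation:
def approx_equal(first, second, epsilon=1e-4):
    # True iff no element of |first - second| exceeds epsilon.
    return torch.max((first - second).abs()).item() <= epsilon
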
def test_inv_quad_log_det_many_vectors(self):
    # Forward pass
    actual_inv_quad = (
        self.mat_var_clone.inverse()
        .matmul(self.vecs_var_clone)
        .mul(self.vecs_var_clone)
        .sum()
    )
    actual_log_det = self.mat_var_clone.det().log()
    with gpytorch.settings.num_trace_samples(1000):
        nlv = NonLazyTensor(self.mat_var)
        res_inv_quad, res_log_det = nlv.inv_quad_log_det(inv_quad_rhs=self.vecs_var, log_det=True)
    self.assertAlmostEqual(res_inv_quad.item(), actual_inv_quad.item(), places=1)
    self.assertAlmostEqual(res_log_det.item(), actual_log_det.item(), places=1)

    # Backward
    actual_inv_quad.backward()
    actual_log_det.backward()
    res_inv_quad.backward(retain_graph=True)
    res_log_det.backward()
    self.assertTrue(approx_equal(self.mat_var_clone.grad, self.mat_var.grad, epsilon=1e-1))
    self.assertTrue(approx_equal(self.vecs_var_clone.grad, self.vecs_var.grad))

def create_lazy_tensor(self):
    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    c = torch.tensor(
        [[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]],
        dtype=torch.float,
    )
    d = torch.tensor([2], dtype=torch.float)
    e = torch.tensor([5], dtype=torch.float)
    f = torch.tensor([2.5], dtype=torch.float)
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    d.requires_grad_(True)
    e.requires_grad_(True)
    f.requires_grad_(True)
    kp_lazy_tensor = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b), NonLazyTensor(c))
    diag_lazy_tensor = KroneckerProductDiagLazyTensor(
        ConstantDiagLazyTensor(d, diag_shape=3),
        ConstantDiagLazyTensor(e, diag_shape=2),
        ConstantDiagLazyTensor(f, diag_shape=4),
    )
    return KroneckerProductAddedDiagLazyTensor(kp_lazy_tensor, diag_lazy_tensor)

def test_inv_quad_log_det_many_vectors(self):
    # Forward pass
    actual_inv_quad = (
        self.mat_clone.inverse()
        .matmul(self.vecs_clone)
        .mul(self.vecs_clone)
        .sum()
    )
    actual_log_det = self.mat_clone.logdet()
    with gpytorch.settings.num_trace_samples(1000):
        non_lazy_tsr = NonLazyTensor(self.mat)
        res_inv_quad, res_log_det = non_lazy_tsr.inv_quad_log_det(inv_quad_rhs=self.vecs, log_det=True)
    self.assertAlmostEqual(res_inv_quad.item(), actual_inv_quad.item(), places=1)
    self.assertAlmostEqual(res_log_det.item(), actual_log_det.item(), places=1)

    # Backward
    actual_inv_quad.backward()
    actual_log_det.backward()
    res_inv_quad.backward(retain_graph=True)
    res_log_det.backward()
    self.assertLess(torch.max((self.mat_clone.grad - self.mat.grad).abs()).item(), 1e-1)
    self.assertLess(torch.max((self.vecs_clone.grad - self.vecs.grad).abs()).item(), 1e-1)

def test_inv_quad_log_det_many_vectors_improper(self):
    # Forward pass
    actual_inv_quad = (
        torch.cat([mat.inverse().unsqueeze(0) for mat in self.mats_clone])
        .matmul(self.vecs_clone)
        .mul(self.vecs_clone)
        .sum(2)
        .sum(1)
    )
    actual_log_det = torch.cat([mat.logdet().unsqueeze(0) for mat in self.mats_clone])
    with gpytorch.settings.num_trace_samples(2000), gpytorch.settings.skip_logdet_forward(True):
        non_lazy_tsr = NonLazyTensor(self.mats)
        res_inv_quad, res_log_det = non_lazy_tsr.inv_quad_log_det(inv_quad_rhs=self.vecs, log_det=True)
    self.assertEqual(res_inv_quad.shape, actual_inv_quad.shape)
    self.assertEqual(res_log_det.shape, actual_log_det.shape)
    self.assertLess(torch.max((res_inv_quad - actual_inv_quad).abs()).item(), 1e-1)
    # With skip_logdet_forward, the forward-pass log det is returned as zeros
    # (gradients still flow), so we only check that it is near zero.
    self.assertLess(torch.max(res_log_det.abs()).item(), 1e-1)

    # Backward
    inv_quad_grad_output = torch.randn(5, dtype=torch.float)
    log_det_grad_output = torch.randn(5, dtype=torch.float)
    actual_inv_quad.backward(gradient=inv_quad_grad_output)
    actual_log_det.backward(gradient=log_det_grad_output)
    res_inv_quad.backward(gradient=inv_quad_grad_output, retain_graph=True)
    res_log_det.backward(gradient=log_det_grad_output)
    self.assertLess(torch.max((self.mats_clone.grad - self.mats.grad).abs()).item(), 1e-1)
    self.assertLess(torch.max((self.vecs_clone.grad - self.vecs.grad).abs()).item(), 1e-1)

def test_matmul_vec_random_rectangular(self):
    ax = torch.randn(4, 2, 3, requires_grad=True)
    bx = torch.randn(4, 5, 2, requires_grad=True)
    cx = torch.randn(4, 6, 4, requires_grad=True)
    rhsx = torch.randn(4, 3 * 2 * 4, 1)
    rhsx = (rhsx / torch.norm(rhsx)).requires_grad_(True)
    ax_copy = ax.clone().detach().requires_grad_(True)
    bx_copy = bx.clone().detach().requires_grad_(True)
    cx_copy = cx.clone().detach().requires_grad_(True)
    rhsx_copy = rhsx.clone().detach().requires_grad_(True)
    kp_lazy_var = KroneckerProductLazyTensor(NonLazyTensor(ax), NonLazyTensor(bx), NonLazyTensor(cx))
    res = kp_lazy_var.matmul(rhsx)
    actual_mat = kron(kron(ax_copy, bx_copy), cx_copy)
    actual = actual_mat.matmul(rhsx_copy)
    self.assertTrue(approx_equal(res, actual))

    actual.sum().backward()
    res.sum().backward()
    self.assertTrue(approx_equal(ax_copy.grad, ax.grad))
    self.assertTrue(approx_equal(bx_copy.grad, bx.grad))
    self.assertTrue(approx_equal(cx_copy.grad, cx.grad))
    self.assertTrue(approx_equal(rhsx_copy.grad, rhsx.grad))

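# `kron` is a helper defined elsewhere in this test module. A minimal sketch,
# assuming it computes a dense (optionally batched) Kronecker product; the
# batched branch applies torch.kron matrix-by-matrix along the leading dimension,
# which matches how KroneckerProductLazyTensor treats batch dimensions:
def kron(a, b):
    if a.dim() == 2:
        # Plain 2D Kronecker product.
        return torch.kron(a, b)
    # Batched case: (batch, n, m) x (batch, p, q) -> (batch, n * p, m * q).
    return torch.stack([torch.kron(a_i, b_i) for a_i, b_i in zip(a, b)])
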
def test_inv_quad_many_vectors(self):
    # Forward pass
    flattened_mats = self.mats_clone.view(-1, *self.mats_clone.shape[-2:])
    actual_inv_quad = (
        torch.cat([mat.inverse().unsqueeze(0) for mat in flattened_mats])
        .view(self.mats_clone.shape)
        .matmul(self.vecs_clone)
        .mul(self.vecs_clone)
        .sum(-2)
        .sum(-1)
    )
    with gpytorch.settings.num_trace_samples(2000):
        non_lazy_tsr = NonLazyTensor(self.mats)
        res_inv_quad = non_lazy_tsr.inv_quad(self.vecs)
    self.assertEqual(res_inv_quad.shape, actual_inv_quad.shape)
    self.assertLess(torch.max((res_inv_quad - actual_inv_quad).abs()).item(), 1e-1)

    # Backward
    inv_quad_grad_output = torch.randn(2, 3, dtype=torch.float)
    actual_inv_quad.backward(gradient=inv_quad_grad_output)
    res_inv_quad.backward(gradient=inv_quad_grad_output, retain_graph=True)
    self.assertLess(torch.max((self.mats_clone.grad - self.mats.grad).abs()).item(), 1e-1)
    self.assertLess(torch.max((self.vecs_clone.grad - self.vecs.grad).abs()).item(), 1e-1)

def create_lazy_tensor(self):
    a = torch.randn(2, 3, requires_grad=True)
    b = torch.randn(5, 2, requires_grad=True)
    c = torch.randn(6, 4, requires_grad=True)
    kp_lazy_tensor = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b), NonLazyTensor(c))
    return kp_lazy_tensor

def create_lazy_tensor(self):
    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    c = torch.tensor(
        [[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]],
        dtype=torch.float,
    )
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    kp_lazy_tensor = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b), NonLazyTensor(c))
    return kp_lazy_tensor

def test_add_diag_single_element(self):
    diag = torch.tensor(1.5)
    res = NonLazyTensor(self.mat).add_diag(diag).evaluate()
    actual = self.mat + torch.eye(self.mat.size(-1)).unsqueeze(0).mul(1.5)
    self.assertTrue(approx_equal(res, actual))

    diag = torch.tensor([1.5])
    res = NonLazyTensor(self.mat).add_diag(diag).evaluate()
    actual = self.mat + torch.eye(self.mat.size(-1)).unsqueeze(0).mul(1.5)
    self.assertTrue(approx_equal(res, actual))

def test_add_diag_different_elements_on_diagonal(self):
    diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.0])
    res = NonLazyTensor(self.mat).add_diag(diag).evaluate()
    actual = self.mat + diag.diag().unsqueeze(0)
    self.assertTrue(approx_equal(res, actual))

    diag = torch.tensor([[1.5, 1.3, 1.2, 1.1, 2.0]])
    res = NonLazyTensor(self.mat).add_diag(diag).evaluate()
    actual = self.mat + diag[0].diag().unsqueeze(0)
    self.assertTrue(approx_equal(res, actual))

def test_root_decomposition(self):
    # Forward
    root = NonLazyTensor(self.mat).root_decomposition().root.evaluate()
    res = root.matmul(root.transpose(-1, -2))
    self.assertTrue(approx_equal(res, self.mat))

    # Backward
    sum([m.trace() for m in res]).backward()
    sum([m.trace() for m in self.mat_clone]).backward()
    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))

def test_root_decomposition(self):
    # Forward
    root = NonLazyTensor(self.mat_var).root_decomposition()
    res = root.matmul(root.transpose(-1, -2))
    self.assertTrue(approx_equal(res, self.mat_var))

    # Backward
    res.trace().backward()
    self.mat_var_clone.trace().backward()
    self.assertTrue(approx_equal(self.mat_var.grad, self.mat_var_clone.grad))

def test_matmul_multiple_vecs(self):
    # Forward
    res = NonLazyTensor(self.mats).matmul(self.vecs)
    actual = self.mats_copy.matmul(self.vecs_copy)
    self.assertTrue(approx_equal(res, actual))

    # Backward
    grad_output = torch.randn(3, 4, 5, 2)
    res.backward(gradient=grad_output)
    actual.backward(gradient=grad_output)
    self.assertTrue(approx_equal(self.mats_copy.grad, self.mats.grad))
    self.assertTrue(approx_equal(self.vecs_copy.grad, self.vecs.grad))

def create_lazy_tensor(self):
    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    c = torch.tensor(
        [[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]],
        dtype=torch.float,
    )
    d = 0.5 * torch.rand(24, dtype=torch.float)
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    d.requires_grad_(True)
    kp_lazy_tensor = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b), NonLazyTensor(c))
    diag_lazy_tensor = DiagLazyTensor(d)
    return KroneckerProductAddedDiagLazyTensor(kp_lazy_tensor, diag_lazy_tensor)

def create_lazy_tensor(self):
    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    c = torch.tensor(
        [[4, 0.5, 1, 0], [0.5, 4, -1, 0], [1, -1, 3, 0], [0, 0, 0, 4]],
        dtype=torch.float,
    )
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    kp_lazy_tensor = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b), NonLazyTensor(c))
    diag_lazy_tensor = ConstantDiagLazyTensor(
        torch.tensor([0.25], dtype=torch.float, requires_grad=True),
        kp_lazy_tensor.shape[-1],
    )
    return KroneckerProductAddedDiagLazyTensor(kp_lazy_tensor, diag_lazy_tensor)

def test_matmul_vec(self):
    # Forward
    res = NonLazyTensor(self.mat).matmul(self.vec)
    actual = self.mat_copy.matmul(self.vec_copy)
    self.assertTrue(approx_equal(res, actual))

    # Backward
    grad_output = torch.randn(3)
    res.backward(gradient=grad_output)
    actual.backward(gradient=grad_output)
    self.assertTrue(approx_equal(self.mat_copy.grad, self.mat.grad))
    self.assertTrue(approx_equal(self.vec_copy.grad, self.vec.grad))

def test_log_det_only(self):
    # Forward pass
    with gpytorch.settings.num_trace_samples(1000):
        res = NonLazyTensor(self.mat).log_det()
    actual = self.mat_clone.logdet()
    self.assertAlmostEqual(res.item(), actual.item(), places=1)

    # Backward
    actual.backward()
    res.backward()
    self.assertLess(torch.max((self.mat_clone.grad - self.mat.grad).abs()).item(), 1e-1)

def create_lazy_tensor(self):
    root = torch.randn(5, 3, 6, 7)
    self.psd_mat = root.matmul(root.transpose(-2, -1))
    slice1_mat = self.psd_mat[:2, ...].requires_grad_()
    slice2_mat = self.psd_mat[2:3, ...].requires_grad_()
    slice3_mat = self.psd_mat[3:, ...].requires_grad_()
    slice1 = NonLazyTensor(slice1_mat)
    slice2 = NonLazyTensor(slice2_mat)
    slice3 = NonLazyTensor(slice3_mat)
    return CatLazyTensor(slice1, slice2, slice3, dim=0)

def create_lazy_tensor(self):
    root = torch.randn(3, 6, 7)
    self.psd_mat = root.matmul(root.transpose(-2, -1))
    slice1_mat = self.psd_mat[..., :2, :].requires_grad_()
    slice2_mat = self.psd_mat[..., 2:4, :].requires_grad_()
    slice3_mat = self.psd_mat[..., 4:6, :].requires_grad_()
    slice1 = NonLazyTensor(slice1_mat)
    slice2 = NonLazyTensor(slice2_mat)
    slice3 = NonLazyTensor(slice3_mat)
    return CatLazyTensor(slice1, slice2, slice3, dim=-2)

def test_root_decomposition(self):
    mat = self._create_mat().detach().requires_grad_(True)
    mat_clone = mat.detach().clone().requires_grad_(True)

    # Forward
    root = NonLazyTensor(mat).root_decomposition().root.evaluate()
    res = root.matmul(root.transpose(-1, -2))
    self.assertAllClose(res, mat)

    # Backward
    sum([m.trace() for m in res.view(-1, mat.size(-2), mat.size(-1))]).backward()
    sum([m.trace() for m in mat_clone.view(-1, mat.size(-2), mat.size(-1))]).backward()
    self.assertAllClose(mat.grad, mat_clone.grad)

def create_lazy_tensor(self):
    root = torch.randn(6, 7)
    self.psd_mat = root.matmul(root.t())
    slice1_mat = self.psd_mat[:2, :].requires_grad_()
    slice2_mat = self.psd_mat[2:4, :].requires_grad_()
    slice3_mat = self.psd_mat[4:6, :].requires_grad_()
    slice1 = NonLazyTensor(slice1_mat)
    slice2 = NonLazyTensor(slice2_mat)
    slice3 = NonLazyTensor(slice3_mat)
    return CatLazyTensor(slice1, slice2, slice3, dim=-2)

def mask_dependent_covar(self, M1s, U1, M2s, U2, covar_xx):
    # Assume M1s and M2s are sorted in descending order.
    B = M1s.shape[:-1]  # batch shape (currently unused)
    M1s = M1s[..., 0]
    idxs1 = torch.nonzero(M1s - torch.ones_like(M1s))
    # Since the mask is sorted, the first non-one entry marks the cut point.
    idxend1 = torch.min(idxs1).item() if idxs1.numel() else M1s.size(-1)
    assert (M1s[..., idxend1:] == 0).all()
    U1s = U1[..., :idxend1, :]

    M2s = M2s[..., 0]
    idxs2 = torch.nonzero(M2s - torch.ones_like(M2s))
    idxend2 = torch.min(idxs2).item() if idxs2.numel() else M2s.size(-1)
    assert (M2s[..., idxend2:] == 0).all()
    U2s = U2[..., :idxend2, :]

    V = ensurelazy(self.task_covar_module.V.covar_matrix)
    U = ensurelazy(self.task_covar_module.U.covar_matrix)
    Kxx = ensurelazy(covar_xx)

    k_xx_22 = Kxx[idxend1:, idxend2:]
    if k_xx_22.numel():
        Kij_xx_22 = self.kernel2(k_xx_22, V, U)
    k_xx_11 = Kxx[:idxend1, :idxend2]
    if k_xx_11.numel():
        H1 = BlockDiagLazyTensor(NonLazyTensor(U1s.unsqueeze(1)))
        H2 = BlockDiagLazyTensor(NonLazyTensor(U2s.unsqueeze(1)))
        Kij_xx_11 = self.kernel1(k_xx_11, H1, H2, V, U)

    if k_xx_11.numel() and k_xx_22.numel():
        # Both diagonal blocks are present, so build the off-diagonal blocks
        # and assemble the full 2x2 block covariance lazily.
        k_xx_12 = Kxx[:idxend1, idxend2:]
        assert k_xx_12.numel()
        Kij_xx_12 = self.correlation_kernel_12(k_xx_12, H1, V, U)
        k_xx_21 = Kxx[idxend1:, :idxend2]
        assert k_xx_21.numel()
        Kij_xx_21 = self.correlation_kernel_12(k_xx_21.t(), H2, V, U).t()
        Kij_xx = lazycat(
            [
                lazycat([Kij_xx_11, Kij_xx_12], dim=1),
                lazycat([Kij_xx_21, Kij_xx_22], dim=1),
            ],
            dim=0,
        )
        return Kij_xx
    elif k_xx_22.numel():
        return Kij_xx_22
    else:
        assert k_xx_11.numel()
        return Kij_xx_11

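# `ensurelazy` and `lazycat` are project-specific helpers, not part of gpytorch's
# public API. Minimal sketches of what mask_dependent_covar appears to assume:
# wrap-if-dense, and lazy concatenation via CatLazyTensor (the same class used by
# the CatLazyTensor fixtures above):
from gpytorch.lazy import CatLazyTensor, LazyTensor


def ensurelazy(x):
    # Pass LazyTensors through; wrap dense tensors so downstream code can index lazily.
    return x if isinstance(x, LazyTensor) else NonLazyTensor(x)


def lazycat(lazy_tensors, dim=0):
    # Concatenate a list of LazyTensors along `dim` without evaluating them.
    return CatLazyTensor(*lazy_tensors, dim=dim)
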
def test_log_det_only(self):
    # Forward pass
    with gpytorch.settings.num_trace_samples(1000):
        res = NonLazyTensor(self.mat_var).log_det()
    actual = self.mat_var_clone.det().log()
    self.assertAlmostEqual(res.item(), actual.item(), places=1)

    # Backward
    actual.backward()
    res.backward()
    self.assertTrue(approx_equal(self.mat_var_clone.grad, self.mat_var.grad, epsilon=1e-1))

def test_root_inv_decomposition(self):
    # Forward
    probe_vectors = torch.randn(4, 5)
    test_vectors = torch.randn(4, 5)
    root = NonLazyTensor(self.mat).root_inv_decomposition(probe_vectors, test_vectors).root.evaluate()
    res = root.matmul(root.transpose(-1, -2))
    actual = self.mat_clone.inverse()
    self.assertTrue(approx_equal(res, actual))

    # Backward
    res.trace().backward()
    actual.trace().backward()
    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))

def create_lazy_tensor(self):
    a = torch.tensor([[4, 0, 2], [0, 3, -1], [2, -1, 3]], dtype=torch.float)
    b = torch.tensor([[2, 1], [1, 2]], dtype=torch.float)
    c = torch.tensor([[4, 0.5, 1], [0.5, 4, -1], [1, -1, 3]], dtype=torch.float)
    d = torch.tensor([[1.2, 0.75], [0.75, 1.2]], dtype=torch.float)
    a.requires_grad_(True)
    b.requires_grad_(True)
    c.requires_grad_(True)
    d.requires_grad_(True)
    kp_lt_1 = KroneckerProductLazyTensor(NonLazyTensor(a), NonLazyTensor(b))
    kp_lt_2 = KroneckerProductLazyTensor(NonLazyTensor(c), NonLazyTensor(d))
    return SumKroneckerLazyTensor(kp_lt_1, kp_lt_2)

def test_inv_matmul_multiple_vecs(self):
    # Forward
    res = NonLazyTensor(self.mat_var).inv_matmul(self.vecs_var)
    actual = self.mat_var_copy.inverse().matmul(self.vecs_var_copy)
    self.assertTrue(approx_equal(res, actual))

    # Backward
    grad_output = torch.randn(3, 4)
    res.backward(gradient=grad_output)
    actual.backward(gradient=grad_output)
    self.assertTrue(approx_equal(self.mat_var_copy.grad, self.mat_var.grad))
    self.assertTrue(approx_equal(self.vecs_var_copy.grad, self.vecs_var.grad))

def test_root_inv_decomposition(self):
    # Forward
    probe_vectors = torch.randn(3, 4, 5)
    test_vectors = torch.randn(3, 4, 5)
    root = NonLazyTensor(self.mat).root_inv_decomposition(probe_vectors, test_vectors).root.evaluate()
    res = root.matmul(root.transpose(-1, -2))
    actual = torch.cat([m.inverse().unsqueeze(0) for m in self.mat_clone])
    self.assertTrue(approx_equal(res, actual))

    # Backward
    sum([m.trace() for m in res]).backward()
    sum([m.trace() for m in actual]).backward()
    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))

def test_add_diag_different_batch(self):
    diag = torch.tensor([
        [1.5, 1.3, 1.2, 1.1, 2.0],
        [0.1, 0.2, 0.3, 0.4, 0.0],
        [0.0, 0.1, 1.3, 1.4, 0.0],
    ])
    res = NonLazyTensor(self.mat).add_diag(diag).evaluate()
    actual = self.mat + torch.cat([
        diag[0].diag().unsqueeze(0),
        diag[1].diag().unsqueeze(0),
        diag[2].diag().unsqueeze(0),
    ])
    self.assertTrue(approx_equal(res, actual))

    diag = torch.tensor([[1.5], [1.3], [0.1]])
    res = NonLazyTensor(self.mat).add_diag(diag).evaluate()
    actual = self.mat + torch.eye(5).unsqueeze(0) * diag.unsqueeze(-1)
    self.assertTrue(approx_equal(res, actual))

def create_lazy_tensor(self):
    tensor = torch.randn(3, 5, 5)
    # A^T A is symmetric positive semi-definite for each batch entry.
    tensor = tensor.transpose(-1, -2).matmul(tensor).detach()
    diag = torch.tensor(
        [[1.0, 2.0, 4.0, 2.0, 3.0], [2.0, 1.0, 2.0, 1.0, 4.0], [1.0, 2.0, 2.0, 3.0, 4.0]],
        requires_grad=True,
    )
    return AddedDiagLazyTensor(NonLazyTensor(tensor), DiagLazyTensor(diag))