def test_solve_qr(self, dtype=torch.float64, tol=1e-8):
    """Check the preconditioner's solve and log-determinant against a dense reference.

    Builds an RBF kernel matrix plus a random diagonal ``noise``, asks the lazy
    tensor for its preconditioner, and compares the preconditioner's solve and
    log-determinant with dense computations on ``diag(noise) + F F^T``, where
    ``F`` is the lazy tensor's ``_piv_chol_self`` factor.

    Args:
        dtype: floating dtype used for all tensors.
        tol: tolerance passed to ``approx_equal``.
    """
    size = 50
    X = torch.rand((size, 2)).to(dtype=dtype)
    y = torch.sin(torch.sum(X, 1)).unsqueeze(-1).to(dtype=dtype)
    with settings.min_preconditioning_size(0):
        # Noise drawn log-uniformly from [1e-3, 1e-1].
        noise = torch.DoubleTensor(size).uniform_(
            math.log(1e-3), math.log(1e-1)).exp_().to(dtype=dtype)
        lazy_tsr = RBFKernel().to(dtype=dtype)(X).evaluate_kernel().add_diag(noise)
        precondition_qr, _, logdet_qr = lazy_tsr._preconditioner()

        # Dense form of the matrix the preconditioner is built from.
        F = lazy_tsr._piv_chol_self
        M = noise.diag() + F.matmul(F.t())

        # torch.solve / torch.cholesky are deprecated (torch.solve is removed in
        # recent PyTorch); use the torch.linalg equivalents used elsewhere in
        # this file. Note torch.linalg.solve takes (A, B), not (B, A).
        x_exact = torch.linalg.solve(M, y)
        x_qr = precondition_qr(y)
        self.assertTrue(approx_equal(x_exact, x_qr, tol))

        # log|M| = 2 * sum(log(diag(chol(M)))).
        logdet = 2 * torch.linalg.cholesky(M).diag().log().sum(-1)
        self.assertTrue(approx_equal(logdet, logdet_qr, tol))
def test_lanczos(self):
    """Full-length Lanczos on a PSD matrix should reconstruct it as Q T Q^T."""
    n = 100
    base = torch.randn(n, n)
    # B B^T is symmetric PSD; normalise it, then add a 1e-6 ridge so it is
    # strictly positive definite.
    matrix = base.matmul(base.transpose(-1, -2))
    matrix.div_(matrix.norm())
    ridge = torch.ones(matrix.size(-1)).mul(1e-6).diag()
    matrix.add_(ridge)

    q_mat, t_mat = lanczos_tridiag(
        matrix.matmul,
        max_iter=n,
        dtype=matrix.dtype,
        device=matrix.device,
        matrix_shape=matrix.shape,
    )
    reconstructed = q_mat.matmul(t_mat).matmul(q_mat.transpose(-1, -2))
    self.assertTrue(approx_equal(reconstructed, matrix))
def lanczos_tridiag_test(self, matrix):
    """Run Lanczos for ``matrix.shape[0]`` iterations and validate the result.

    Checks the factors with the class helpers ``assert_valid_sizes`` and
    ``assert_tridiagonally_positive``, then checks that ``Q T Q^T``
    reproduces ``matrix``.
    """
    n = matrix.shape[0]
    q_mat, t_mat = lanczos_tridiag(
        matrix.matmul,
        max_iter=n,
        dtype=matrix.dtype,
        device=matrix.device,
        matrix_shape=matrix.shape,
    )
    self.assert_valid_sizes(n, t_mat, q_mat)
    self.assert_tridiagonally_positive(t_mat)

    q_transposed = q_mat.transpose(-1, -2)
    reconstructed = q_mat.matmul(t_mat).matmul(q_transposed)
    self.assertTrue(approx_equal(reconstructed, matrix))
def test_solve(self):
    """Woodbury log-determinant and solve for a pivoted-Cholesky factor plus identity.

    ``piv_chol`` is the pivoted Cholesky factor (rank argument 10) of an RBF
    covariance ``K``. With a unit diagonal, the Woodbury logdet is compared
    with ``logdet(L L^T + I)``. The Woodbury solve is compared with a dense
    solve against ``K + I``, within a 2e-4 tolerance.
    """
    size = 100
    train_x = torch.linspace(0, 1, size)
    covar_matrix = RBFKernel()(train_x, train_x).evaluate()
    piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)

    # Every shape is derived from ``size`` (not a repeated literal 100) so
    # changing ``size`` cannot silently desynchronise the test.
    woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
        piv_chol, piv_chol, torch.ones(size), logdet=True)
    expected_logdet = (piv_chol @ piv_chol.transpose(-1, -2) + torch.eye(size)).logdet()
    self.assertTrue(approx_equal(logdet, expected_logdet, 2e-4))

    rhs_vector = torch.randn(size, 50)
    shifted_covar_matrix = covar_matrix + torch.eye(size)
    real_solve = shifted_covar_matrix.inverse().matmul(rhs_vector)
    scaled_inv_diag = (inv_scale / torch.ones(size)).unsqueeze(-1)
    approx_solve = woodbury.woodbury_solve(
        rhs_vector, piv_chol * scaled_inv_diag, woodbury_factor, scaled_inv_diag, inv_scale)
    self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))
def test_interpolation(self):
    """Interpolating f(t) = t**2 from grid values should reproduce f at ``x``."""
    x = torch.linspace(0.01, 1, 100).unsqueeze(1)
    grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)
    interp_indices, interp_values = Interpolation().interpolate(grid, x)
    # Squeeze dim 0 in place (a no-op unless that dimension has size 1).
    interp_indices = interp_indices.squeeze_(0)
    interp_values = interp_values.squeeze_(0)

    grid_targets = grid.squeeze(1).pow(2).unsqueeze(1)
    expected = x.pow(2).squeeze(-1)
    interpolated = left_interp(interp_indices, interp_values, grid_targets).squeeze()
    self.assertTrue(approx_equal(interpolated, expected))
def test_smoothed_box_prior_log_prob_log_transform(self, cuda=False):
    """SmoothedBoxPrior with ``transform=torch.exp`` evaluated on log-space inputs."""
    device = torch.device("cuda" if cuda else "cpu")
    lower = torch.zeros(2, device=device)
    upper = torch.ones(2, device=device)
    prior = SmoothedBoxPrior(lower, upper, 0.1, transform=torch.exp)

    # One 2-vector: log_prob yields a single value.
    point = torch.tensor([0.5, 1.1], device=device).log()
    self.assertAlmostEqual(prior.log_prob(point).item(), -0.9473, places=4)

    # Two rows: one log_prob per row.
    batch = torch.tensor([[0.5, 1.1], [0.1, 0.25]], device=device).log()
    expected = torch.tensor([-0.947347, -0.447347], device=batch.device)
    self.assertTrue(torch.all(approx_equal(prior.log_prob(batch), expected)))

    # A length-3 input does not match the 2-dimensional prior.
    with self.assertRaises(RuntimeError):
        prior.log_prob(torch.ones(3, device=device))
def test_lkj_prior_log_prob(self, cuda=False):
    """LKJPrior log densities at eta=0.5, and a constant density at eta=1."""
    device = torch.device("cuda" if cuda else "cpu")
    prior = LKJPrior(2, torch.tensor(0.5, device=device))

    identity = torch.eye(2, device=device)
    self.assertAlmostEqual(prior.log_prob(identity).item(), -1.86942, places=4)

    correlated = torch.tensor([[1.0, 0.5], [0.5, 1]], device=identity.device)
    batch = torch.stack([identity, correlated])
    expected = torch.tensor([-1.86942, -1.72558], device=batch.device)
    self.assertTrue(approx_equal(prior.log_prob(batch), expected))

    # A 3x3 input does not match the prior's dimension of 2.
    with self.assertRaises(ValueError):
        prior.log_prob(torch.eye(3, device=device))

    # For eta=1.0 log_prob is flat, so every input should score exactly
    # ``prior.C``.
    flat_prior = LKJPrior(2, torch.tensor(1.0, device=device))
    self.assertTrue(torch.all(flat_prior.log_prob(batch) == flat_prior.C))
def test_lkj_cholesky_factor_prior_log_prob(self, cuda=False):
    """LKJCholeskyFactorPrior must agree with the LKJCholesky distribution.

    Both are built with n=2 and concentration 0.5 and evaluated on the
    Cholesky factors of the identity and of a correlated 2x2 matrix.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    prior = LKJCholeskyFactorPrior(2, torch.tensor(0.5, device=device))
    dist = LKJCholesky(2, torch.tensor(0.5, device=device))

    S = torch.eye(2, device=device)
    S_chol = torch.linalg.cholesky(S)
    # Compare Python floats: assertAlmostEqual applies round() to the
    # difference, which is not reliably supported for tensors. This matches
    # the .item()-based scalar checks elsewhere in this file.
    self.assertAlmostEqual(
        prior.log_prob(S_chol).item(), dist.log_prob(S_chol).item(), places=4)

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S_chol.device)])
    # torch.linalg.cholesky is batched, so no per-matrix loop is needed.
    S_chol = torch.linalg.cholesky(S)
    self.assertTrue(approx_equal(prior.log_prob(S_chol), dist.log_prob(S_chol)))
def test_toeplitz_matmul_batch(self):
    """Batched fast Toeplitz matmul must match dense Toeplitz matrices times ``rhs``."""
    cols = torch.tensor([[1, 6, 4, 5], [2, 3, 1, 0], [1, 2, 3, 1]], dtype=torch.float)
    rows = torch.tensor([[1, 2, 1, 1], [2, 0, 0, 1], [1, 5, 1, 0]], dtype=torch.float)
    rhs_mats = torch.randn(3, 4, 2)

    # Reference: materialise each Toeplitz matrix densely, then matmul.
    dense_toeplitz = torch.stack(
        [utils.toeplitz.toeplitz(col, row) for col, row in zip(cols, rows)]
    )
    expected = torch.matmul(dense_toeplitz, rhs_mats)

    # Fast path that works directly from the column/row generators.
    result = utils.toeplitz.toeplitz_matmul(cols, rows, rhs_mats)
    self.assertTrue(approx_equal(result, expected))
def test_lkj_covariance_prior_log_prob_hetsd(self, cuda=False):
    """LKJCovariancePrior with a per-dimension SmoothedBoxPrior on the SDs."""
    device = torch.device("cuda" if cuda else "cpu")
    sd_lower = torch.tensor([exp(-1), exp(-2)], device=device)
    sd_upper = torch.tensor([exp(1), exp(2)], device=device)
    sd_prior = SmoothedBoxPrior(sd_lower, sd_upper)
    prior = LKJCovariancePrior(2, torch.tensor(0.5, device=device), sd_prior)

    identity = torch.eye(2, device=device)
    self.assertAlmostEqual(prior.log_prob(identity).item(), -4.71958, places=4)

    correlated = torch.tensor([[1.0, 0.5], [0.5, 1]], device=identity.device)
    batch = torch.stack([identity, correlated])
    expected = torch.tensor([-4.71958, -4.57574], device=batch.device)
    self.assertTrue(approx_equal(prior.log_prob(batch), expected))

    # A 3x3 input does not match the prior's dimension of 2.
    with self.assertRaises(ValueError):
        prior.log_prob(torch.eye(3, device=device))

    # For eta=1.0 the correlation part is flat, so the density reduces to the
    # correlation prior's constant plus the SD prior on sqrt(diag(S)).
    flat_prior = LKJCovariancePrior(2, torch.tensor(1.0, device=device), sd_prior)
    marginal_sd = torch.diagonal(batch, dim1=-2, dim2=-1).sqrt()
    expected_flat = flat_prior.correlation_prior.C + flat_prior.sd_prior.log_prob(marginal_sd)
    self.assertTrue(approx_equal(flat_prior.log_prob(batch), expected_flat))
def test_lkj_prior_log_prob(self, cuda=False):
    """LKJPrior on matrices must agree with LKJCholesky on their Cholesky factors.

    Both use n=2 and concentration 0.5. Also checks that a 3x3 input raises
    ``ValueError``.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    prior = LKJPrior(2, torch.tensor(0.5, device=device))
    dist = LKJCholesky(2, torch.tensor(0.5, device=device))

    S = torch.eye(2, device=device)
    S_chol = torch.linalg.cholesky(S)
    # Compare Python floats: assertAlmostEqual applies round() to the
    # difference, which is not reliably supported for tensors. This matches
    # the .item()-based scalar checks elsewhere in this file.
    self.assertAlmostEqual(prior.log_prob(S).item(), dist.log_prob(S_chol).item(), places=4)

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
    S_chol = torch.linalg.cholesky(S)
    self.assertTrue(approx_equal(prior.log_prob(S), dist.log_prob(S_chol)))

    # A 3x3 input does not match the prior's dimension of 2.
    with self.assertRaises(ValueError):
        prior.log_prob(torch.eye(3, device=device))
def test_get_item_tensor_index_on_batch(self):
    """Tensor/slice/Ellipsis indexing of a batched ZeroLazyTensor matches the dense tensor."""
    # Exercises the default (LV / lazy tensor) __getitem__ behaviour.
    lazy_tensor = ZeroLazyTensor(3, 5, 5)
    evaluated = lazy_tensor.evaluate()
    full = slice(None, None, None)

    # Each entry is (index, result_needs_evaluate). When the flag is True,
    # the indexed result is still lazy and must be evaluated before comparing.
    cases = [
        ((torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2]), torch.tensor([1, 2, 0, 1])), False),
        ((torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2]), full), False),
        ((torch.tensor([0, 1, 1]), full, torch.tensor([0, 1, 2])), False),
        ((full, torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2])), False),
        ((torch.tensor([0, 0, 1, 1]), full, full), True),
        ((full, torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])), False),
        # NOTE(review): identical to the second case; possibly meant to be a
        # different index.
        ((torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2]), full), False),
        ((torch.tensor([0, 0, 1, 0]), full, torch.tensor([0, 0, 1, 1])), False),
        ((Ellipsis, torch.tensor([0, 1, 1, 0])), True),
    ]
    for index, needs_evaluate in cases:
        result = lazy_tensor[index]
        if needs_evaluate:
            result = result.evaluate()
        self.assertTrue(approx_equal(result, evaluated[index]))
def test_lkj_covariance_prior_log_prob_hetsd(self, cuda=False):
    """LKJCovariancePrior must equal LKJCholesky on the correlation factor plus
    the SD prior on the marginal standard deviations.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    a = torch.tensor([exp(-1), exp(-2)], device=device)
    b = torch.tensor([exp(1), exp(2)], device=device)
    sd_prior = SmoothedBoxPrior(a, b)
    prior = LKJCovariancePrior(2, torch.tensor(0.5, device=device), sd_prior)
    corr_dist = LKJCholesky(2, torch.tensor(0.5, device=device))

    def reference_log_prob(cov):
        # Split ``cov`` into marginal SDs and a correlation matrix, then score
        # each part. LKJCholesky takes the Cholesky factor of the correlation
        # matrix (not the covariance). The SD prior is applied to
        # sqrt(diag(cov)), as in the eta=1 decomposition
        # C + sd_prior.log_prob(sqrt(diag)), not to the variances.
        sd = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
        corr = cov / (sd.unsqueeze(-1) * sd.unsqueeze(-2))
        return corr_dist.log_prob(torch.linalg.cholesky(corr)) + sd_prior.log_prob(sd)

    S = torch.eye(2, device=device)
    # .item(): assertAlmostEqual applies round() to the difference, which is
    # only reliable for Python scalars.
    self.assertAlmostEqual(prior.log_prob(S).item(), reference_log_prob(S).item(), places=4)

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
    self.assertTrue(approx_equal(prior.log_prob(S), reference_log_prob(S)))
def test_sample(self):
    """The fraction of SmoothedBoxPrior samples outside [a, b] should match
    ``1 / (gauss_max * (b - a) + 1)`` to within ``epsilon=0.005``.
    """
    a = torch.as_tensor(0.0)
    b = torch.as_tensor(1.0)
    sigma = 0.01
    peak_density = 1 / (math.sqrt(2 * math.pi) * sigma)
    expected_fraction = 1 / (peak_density * (b - a) + 1)
    prior = SmoothedBoxPrior(a, b, sigma)

    n_samples = 50000
    draws = prior.sample((n_samples, ))
    outside_box = (draws < a) | (draws > b)
    n_outside = draws[outside_box].shape[0]
    observed_fraction = torch.as_tensor(n_outside / n_samples)
    self.assertTrue(
        torch.all(approx_equal(observed_fraction, expected_fraction, epsilon=0.005)))
def test_add_diag(self):
    """ZeroLazyTensor.add_diag with scalar, length-1, full and per-batch diagonals."""
    full_diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.0])
    per_batch_diag = torch.tensor([[1.5, 1.3, 1.2, 1.1, 2.0], [0, 1, 2, 1, 1]])
    scaled_eye = torch.eye(5).mul(1.5)
    batched_scaled_eye = torch.eye(5).unsqueeze(0).repeat(2, 1, 1).mul(1.5)

    # Each entry is (ZeroLazyTensor shape, diagonal to add, expected dense result).
    cases = [
        # Non-batched 5x5.
        ((5, 5), torch.tensor(1.5), scaled_eye),
        ((5, 5), torch.tensor([1.5]), scaled_eye),
        ((5, 5), full_diag, full_diag.diag()),
        # Batch of two 5x5: one diagonal shared across the batch.
        ((2, 5, 5), torch.tensor(1.5), batched_scaled_eye),
        ((2, 5, 5), torch.tensor([1.5]), batched_scaled_eye),
        ((2, 5, 5), full_diag, full_diag.diag().unsqueeze(0).repeat(2, 1, 1)),
        # Batch of two 5x5: a separate diagonal per batch entry.
        ((2, 5, 5), per_batch_diag, torch.stack([row.diag() for row in per_batch_diag])),
    ]
    for shape, diag, expected in cases:
        result = ZeroLazyTensor(*shape).add_diag(diag).evaluate()
        self.assertTrue(approx_equal(result, expected))
def test_left_interp_on_a_vector(self):
    """left_interp on a length-6 vector equals multiplying by ``self.interp_matrix``."""
    rhs = torch.randn(6)
    interpolated = left_interp(self.interp_indices, self.interp_values, rhs)
    dense_product = torch.matmul(self.interp_matrix, rhs)
    self.assertTrue(approx_equal(interpolated, dense_product))
def test_multidim_interpolation(self):
    """Check Interpolation().interpolate on 4 points in 3-D against fixed expectations.

    ``x`` holds 4 points in 3 dimensions (rows after ``.t()``). ``grid`` has 11
    evenly spaced values in [0, 1] for each of the 3 dimensions. The expected
    ``indices`` and ``values`` each have 64 columns per point (concatenated
    along dim 1 from the literal blocks below).
    """
    x = torch.tensor([[0.25, 0.45, 0.65, 0.85], [0.35, 0.375, 0.4, 0.425], [0.45, 0.5, 0.55, 0.6]]).t().contiguous()
    grid = torch.linspace(0.0, 1.0, 11).unsqueeze(1).repeat(1, 3)
    indices, values = Interpolation().interpolate(grid, x)

    # Expected flat grid indices, one row per point, split into column chunks
    # (13 + 13 + 13 + 13 + 12 = 64 columns).
    # NOTE(review): the literals look like i*121 + j*11 + k over the 11**3
    # grid (e.g. 146 = 1*121 + 2*11 + 3). This is inferred from the numbers
    # only; confirm against Interpolation.
    actual_indices = torch.cat(
        [
            torch.tensor(
                [
                    [146, 147, 148, 149, 157, 158, 159, 160, 168, 169, 170, 171, 179],
                    [389, 390, 391, 392, 400, 401, 402, 403, 411, 412, 413, 414, 422],
                    [642, 643, 644, 645, 653, 654, 655, 656, 664, 665, 666, 667, 675],
                    [885, 886, 887, 888, 896, 897, 898, 899, 907, 908, 909, 910, 918],
                ],
                dtype=torch.long,
            ),
            torch.tensor(
                [
                    [180, 181, 182, 267, 268, 269, 270, 278, 279, 280, 281, 289, 290],
                    [423, 424, 425, 510, 511, 512, 513, 521, 522, 523, 524, 532, 533],
                    [676, 677, 678, 763, 764, 765, 766, 774, 775, 776, 777, 785, 786],
                    [919, 920, 921, 1006, 1007, 1008, 1009, 1017, 1018, 1019, 1020, 1028, 1029],
                ],
                dtype=torch.long,
            ),
            torch.tensor(
                [
                    [291, 292, 300, 301, 302, 303, 388, 389, 390, 391, 399, 400, 401],
                    [534, 535, 543, 544, 545, 546, 631, 632, 633, 634, 642, 643, 644],
                    [787, 788, 796, 797, 798, 799, 884, 885, 886, 887, 895, 896, 897],
                    [1030, 1031, 1039, 1040, 1041, 1042, 1127, 1128, 1129, 1130, 1138, 1139, 1140],
                ],
                dtype=torch.long,
            ),
            torch.tensor(
                [
                    [402, 410, 411, 412, 413, 421, 422, 423, 424, 509, 510, 511, 512],
                    [645, 653, 654, 655, 656, 664, 665, 666, 667, 752, 753, 754, 755],
                    [898, 906, 907, 908, 909, 917, 918, 919, 920, 1005, 1006, 1007, 1008],
                    [1141, 1149, 1150, 1151, 1152, 1160, 1161, 1162, 1163, 1248, 1249, 1250, 1251],
                ],
                dtype=torch.long,
            ),
            torch.tensor(
                [
                    [520, 521, 522, 523, 531, 532, 533, 534, 542, 543, 544, 545],
                    [763, 764, 765, 766, 774, 775, 776, 777, 785, 786, 787, 788],
                    [1016, 1017, 1018, 1019, 1027, 1028, 1029, 1030, 1038, 1039, 1040, 1041],
                    [1259, 1260, 1261, 1262, 1270, 1271, 1272, 1273, 1281, 1282, 1283, 1284],
                ],
                dtype=torch.long,
            ),
        ],
        1,
    )
    self.assertTrue(approx_equal(indices, actual_indices))

    # Expected interpolation weights matching ``actual_indices`` column for
    # column (10 * 6 + 4 = 64 columns). They are rounded to 4 decimals, so this
    # presumably relies on approx_equal's default tolerance.
    actual_values = torch.cat(
        [
            torch.tensor([
                [-0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022, 0.0022, -0.0198],
                [0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000, -0.0000, -0.0542],
                [0.0000, -0.0000, -0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039, 0.0000, -0.0000],
                [0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000, -0.0000, -0.0142],
            ]),
            torch.tensor([
                [-0.0198, 0.0022, -0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022],
                [-0.0000, -0.0000, 0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0132, -0.0000, -0.0000],
                [-0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
                [-0.0000, -0.0000, 0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0396, -0.0000, -0.0000],
            ]),
            torch.tensor([
                [-0.0198, 0.1780, 0.1780, -0.0198, -0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198],
                [0.0000, 0.1274, 0.0000, 0.0000, 0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396],
                [-0.0352, 0.3164, 0.3164, -0.0352, -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
                [0.0000, 0.4878, 0.0000, 0.0000, 0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132],
            ]),
            torch.tensor([
                [-0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022, -0.0198, 0.1780, 0.1780, -0.0198],
                [-0.0000, -0.0000, -0.0000, -0.0132, -0.0000, -0.0000, 0.0000, 0.1274, 0.0000, 0.0000],
                [0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, -0.0352, 0.3164, 0.3164, -0.0352],
                [-0.0000, -0.0000, -0.0000, -0.0396, -0.0000, -0.0000, 0.0000, 0.4878, 0.0000, 0.0000],
            ]),
            torch.tensor([
                [-0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198, -0.0198, 0.0022, -0.0002, 0.0022],
                [0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396, -0.0000, -0.0000, 0.0000, 0.0015],
                [-0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
                [0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132, -0.0000, -0.0000, 0.0000, 0.0044],
            ]),
            torch.tensor([
                [0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022],
                [0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000, -0.0000, -0.0542, -0.0000, -0.0000],
                [-0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039, 0.0000, -0.0000, -0.0000, 0.0000],
                [0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000, -0.0000, -0.0142, -0.0000, -0.0000],
            ]),
            torch.tensor([
                [-0.0002, 0.0022, 0.0022, -0.0002],
                [0.0000, 0.0044, 0.0000, 0.0000],
                [0.0000, -0.0000, -0.0000, 0.0000],
                [0.0000, 0.0015, 0.0000, 0.0000],
            ]),
        ],
        1,
    )
    self.assertTrue(approx_equal(values, actual_values))
def test_solve(self):
    """Batched Woodbury logdet/solve over a (2, 2, 3) batch of RBF covariances.

    Each of the 12 batch entries uses inputs ``linspace(0, upper, size)``. For
    each, the Woodbury log-determinant is compared with ``logdet(L L^T + I)``
    and the Woodbury solve with a dense solve against ``K + I``, within 2e-4.
    """
    size = 100
    batch_shape = (2, 2, 3)
    # Upper endpoints of the 12 input grids, in row-major order over batch_shape.
    upper_limits = [1, 0.5, 0.25, 1.25, 1.5, 1, 0.5, 0.25, 1.25, 1.25, 1.5, 1]
    train_x = torch.stack(
        [torch.linspace(0, upper, size) for upper in upper_limits], 0
    ).unsqueeze(-1)
    covar_matrix = RBFKernel()(train_x, train_x).evaluate().view(*batch_shape, size, size)
    piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)

    # The unit diagonal uses the full batch shape and ``size`` (no repeated
    # literal 100).
    unit_diag = torch.ones(*batch_shape, size)
    woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
        piv_chol, piv_chol, unit_diag, logdet=True)

    # Tensor.logdet is batched, giving one value per batch entry.
    actual_logdet = (piv_chol @ piv_chol.transpose(-1, -2) + torch.eye(size)).logdet()
    self.assertTrue(approx_equal(logdet, actual_logdet, 2e-4))

    rhs_vector = torch.randn(*batch_shape, size, 5)
    shifted_covar_matrix = covar_matrix + torch.eye(size)
    # Dense reference, batched over all 12 matrices at once.
    real_solve = shifted_covar_matrix.inverse().matmul(rhs_vector)

    # Divide by a unit diagonal with the same batch shape as the one passed to
    # woodbury_factor, rather than relying on broadcasting a smaller tensor.
    scaled_inv_diag = (inv_scale / torch.ones(*batch_shape, size)).unsqueeze(-1)
    approx_solve = woodbury.woodbury_solve(
        rhs_vector, piv_chol * scaled_inv_diag, woodbury_factor, scaled_inv_diag, inv_scale)
    self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))
def test_left_t_interp_on_a_vector(self):
    """left_t_interp on a length-9 vector equals multiplying by ``interp_matrix^T``."""
    rhs = torch.randn(9)
    interpolated = left_t_interp(self.interp_indices, self.interp_values, rhs, 6)
    dense_product = torch.matmul(self.interp_matrix.transpose(-1, -2), rhs)
    self.assertTrue(approx_equal(interpolated, dense_product))
def test_left_t_interp_on_a_matrix(self):
    """left_t_interp on a 9x3 matrix equals multiplying by ``interp_matrix^T``."""
    rhs = torch.randn(9, 3)
    interpolated = left_t_interp(self.interp_indices, self.interp_values, rhs, 6)
    dense_product = torch.matmul(self.interp_matrix.transpose(-1, -2), rhs)
    self.assertTrue(approx_equal(interpolated, dense_product))