def test_batch_cg_with_tridiag(self):
    """Check batched ``linear_cg`` solves and the eigenvalues of its tridiagonal outputs.

    ``batch`` random SPD systems of order ``size`` are solved for a 10-column
    right-hand side. The solves are compared to a Cholesky-based solve, and
    every returned matrix ``t_mats[j, i]`` must reproduce the eigenvalues of
    ``matrix[i]``.
    """
    batch = 5
    size = 10
    n_tridiag = 8

    # A A^T is positive semi-definite; normalizing it and adding 0.1 * I makes
    # every batch member strictly positive definite, so CG and Cholesky apply.
    matrix = torch.randn(batch, size, size, dtype=torch.float64)
    matrix = matrix.matmul(matrix.transpose(-1, -2))
    matrix.div_(matrix.norm())
    matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
    rhs = torch.randn(batch, size, 10, dtype=torch.float64)

    # The eigenvalue check below compares against all `size` eigenvalues, so
    # request `size` tridiagonal iterations instead of a duplicated literal.
    # tolerance=0 presumably keeps CG from stopping early -- confirm in linear_cg.
    solves, t_mats = linear_cg(
        matrix.matmul,
        rhs=rhs,
        n_tridiag=n_tridiag,
        max_iter=size,
        max_tridiag_iter=size,
        tolerance=0,
        eps=1e-30,
    )

    # Check cg
    matrix_chol = torch.cholesky(matrix)
    actual = cholesky_solve(rhs, matrix_chol)
    self.assertTrue(torch.allclose(solves, actual, atol=1e-3, rtol=1e-4))

    # Check tridiag: loop bounds follow `batch` / `n_tridiag` so they cannot
    # drift out of sync with the setup above. t_mats is indexed [tridiag, batch].
    # NOTE(review): Tensor.symeig is deprecated/removed in newer PyTorch;
    # torch.linalg.eigvalsh is the replacement if the pinned version allows it.
    for i in range(batch):
        eigs = matrix[i].symeig()[0]
        for j in range(n_tridiag):
            approx_eigs = t_mats[j, i].symeig()[0]
            self.assertTrue(torch.allclose(eigs, approx_eigs, atol=1e-3, rtol=1e-4))
def test_cg_with_tridiag(self):
    """Check ``linear_cg`` solves and the eigenvalues of its tridiagonal outputs.

    A single random SPD system of order ``size`` is solved for a 50-column
    right-hand side. The solves are compared to a Cholesky-based solve, and
    each of the ``n_tridiag`` returned matrices must reproduce the eigenvalues
    of ``matrix``.
    """
    size = 10
    n_tridiag = 5

    # A A^T is positive semi-definite; normalizing it and adding 0.1 * I makes
    # it strictly positive definite, so CG and Cholesky both apply.
    matrix = torch.randn(size, size, dtype=torch.float64)
    matrix = matrix.matmul(matrix.transpose(-1, -2))
    matrix.div_(matrix.norm())
    matrix.add_(torch.eye(matrix.size(-1), dtype=torch.float64).mul_(1e-1))
    rhs = torch.randn(size, 50, dtype=torch.float64)

    # The eigenvalue check compares against all `size` eigenvalues, so request
    # `size` tridiagonal iterations rather than repeating the literal.
    # tolerance=0 presumably keeps CG from stopping early -- confirm in linear_cg.
    solves, t_mats = linear_cg(
        matrix.matmul,
        rhs=rhs,
        n_tridiag=n_tridiag,
        max_tridiag_iter=size,
        max_iter=size,
        tolerance=0,
        eps=1e-15,
    )

    # Check cg
    matrix_chol = matrix.cholesky()
    actual = cholesky_solve(rhs, matrix_chol)
    self.assertTrue(approx_equal(solves, actual))

    # Check tridiag: iterate over exactly the n_tridiag matrices requested above.
    # NOTE(review): Tensor.symeig is deprecated/removed in newer PyTorch;
    # torch.linalg.eigvalsh is the replacement if the pinned version allows it.
    eigs = matrix.symeig()[0]
    for i in range(n_tridiag):
        approx_eigs = t_mats[i].symeig()[0]
        self.assertTrue(approx_equal(eigs, approx_eigs))
def test_cg(self):
    """Solve one 100x100 SPD system with ``linear_cg`` and compare to Cholesky.

    The right-hand side has 50 columns; the CG result must agree with
    ``cholesky_solve`` within ``atol=1e-3`` and ``rtol=1e-4``.
    """
    n = 100

    # Random Gram matrix, scaled by its norm, then shifted by 0.1 * I so that
    # it is strictly positive definite.
    gram = torch.randn(n, n, dtype=torch.float64)
    gram = gram @ gram.transpose(-1, -2)
    gram = gram / gram.norm()
    gram = gram + torch.eye(gram.size(-1), dtype=torch.float64) * 0.1

    targets = torch.randn(n, 50, dtype=torch.float64)
    cg_solution = linear_cg(gram.matmul, rhs=targets, max_iter=n)

    # Reference answer from a direct Cholesky factorization.
    expected = cholesky_solve(targets, torch.cholesky(gram))
    self.assertTrue(torch.allclose(cg_solution, expected, atol=1e-3, rtol=1e-4))
def test_batch_cg(self):
    """Solve a batch of five 100x100 SPD systems with ``linear_cg``.

    Each system has a 50-column right-hand side; the batched CG result must
    match a batched ``cholesky_solve`` according to ``approx_equal``.
    """
    n_batch, n = 5, 100

    # Random Gram matrices, jointly normalized, then shifted by 0.1 * I
    # (broadcast across the batch) so each one is strictly positive definite.
    grams = torch.randn(n_batch, n, n, dtype=torch.float64)
    grams = grams @ grams.transpose(-1, -2)
    grams = grams / grams.norm()
    grams = grams + torch.eye(grams.size(-1), dtype=torch.float64) * 0.1

    targets = torch.randn(n_batch, n, 50, dtype=torch.float64)
    cg_solution = linear_cg(grams.matmul, rhs=targets, max_iter=n)

    # Reference answer from a direct batched Cholesky factorization.
    expected = cholesky_solve(targets, grams.cholesky())
    self.assertTrue(approx_equal(cg_solution, expected))