def test_linear_solve_matrix_tape(self):
    """Check what `math.SolveTape` records for solves of the jit-compiled `field.laplace`.

    For every solver method the same solve is executed twice: once under a
    default tape and once under a tape created with
    `record_trajectories=True`. The solution values, the number of recorded
    solves, the residual, the iteration count and the trajectory size are
    asserted.
    """
    rhs = CenteredGrid(1, extrapolation.ZERO, x=3) * (1, 2)
    initial_guess = CenteredGrid(0, extrapolation.ZERO, x=3)
    expected = [[-1.5, -2, -1.5], [-3, -4, -3]]
    tolerance = 1e-3

    for method in ('CG', 'CG-adaptive', 'auto'):
        solve = math.Solve(method, 0, 1e-3, x0=initial_guess, max_iterations=100)

        # First pass: default tape, only the final result is inspected.
        with math.SolveTape() as solves:
            x = field.solve_linear(math.jit_compile_linear(field.laplace), rhs, solve)
        math.assert_close(x.values, expected, abs_tolerance=tolerance)
        assert len(solves) == 1
        # The tape can be indexed both by position and by the Solve object.
        assert solves[0] == solves[solve]
        math.assert_close(solves[solve].residual.values, 0, abs_tolerance=tolerance)
        iterations = solves[solve].iterations
        # Either 2 or -1 is accepted as the reported iteration count.
        assert math.close(iterations, 2) or math.close(iterations, -1)

        # Second pass: additionally record the trajectory.
        with math.SolveTape(record_trajectories=True) as solves:
            x = field.solve_linear(math.jit_compile_linear(field.laplace), rhs, solve)
        math.assert_close(x.values, expected, abs_tolerance=tolerance)
        assert solves[solve].x.trajectory.size == 3
        final_residual = solves[solve].residual.trajectory[-1]
        math.assert_close(final_residual.values, 0, abs_tolerance=tolerance)
def solve(y, method):
    """Run a linear solve of the jit-compiled `field.laplace` inside a `SolveTape`.

    Prints a tracing message first. `backend` and `x0` are not parameters;
    they are read from the surrounding scope, which is not visible here.

    Args:
        y: Right-hand side passed to `field.solve_linear`.
        method: Solver method name forwarded to `math.Solve`.

    Returns:
        The result of `field.solve_linear`.
    """
    print(f"Tracing {method} with {backend}...")
    solve_spec = math.Solve(method, 0, 1e-3, x0=x0, max_iterations=100)
    # The tape contents are not inspected; the solve merely runs under it.
    with SolveTape() as _recorded:
        return field.solve_linear(math.jit_compile_linear(field.laplace), y, solve_spec)
def test_sparse_matrix(self):
    """Build the sparse matrix of `math.laplace` in each format on every backend.

    For the formats 'csr', 'csc' and 'coo', the matrix obtained for a
    5-element input must report the requested format as its
    `indexing_type` and have shape (5, 5).
    """
    for backend in BACKENDS:
        with backend:
            for fmt in ('csr', 'csc', 'coo'):
                compiled_laplace = math.jit_compile_linear(math.laplace)
                example_input = math.zeros(spatial(x=5))
                matrix = compiled_laplace.sparse_matrix(example_input, format=fmt)
                self.assertEqual(fmt, matrix.indexing_type)
                self.assertEqual((5, 5), matrix.shape)
def test_solve_diverge(self):
    """Expect `Diverged` from a linear solve, both without and within a `SolveTape`.

    The solve of the jit-compiled `math.laplace` is attempted twice per
    method: once directly and once inside a tape created with
    `record_trajectories=True`. Each attempt must raise `Diverged`.
    """
    y = math.ones(spatial(x=2)) * (1, 2)
    x0 = math.zeros(spatial(x=2))
    for method in ['CG']:
        solve = Solve(method, 0, 1e-3, x0=x0, max_iterations=100)
        # assertRaises fails the test when nothing is raised and, unlike a
        # bare `assert False`, keeps doing so under `python -O`.
        with self.assertRaises(Diverged):
            field.solve_linear(math.jit_compile_linear(math.laplace), y, solve)
        # The exception is caught inside the tape context, so the tape
        # exits without seeing it.
        with math.SolveTape(record_trajectories=True):
            with self.assertRaises(Diverged):
                field.solve_linear(math.jit_compile_linear(math.laplace), y, solve)  # impossible
def test_linear_solve_matrix_batched(self):  # TODO also test batched matrix
    """Solve with the jit-compiled `field.laplace` for a right-hand side scaled by (1, 2).

    Each solver method must reproduce the two expected rows within an
    absolute tolerance of 1e-3.
    """
    rhs = CenteredGrid(1, extrapolation.ZERO, x=3) * (1, 2)
    initial_guess = CenteredGrid(0, extrapolation.ZERO, x=3)
    expected = [[-1.5, -2, -1.5], [-3, -4, -3]]
    for method in ('CG', 'CG-adaptive', 'auto'):
        solve_spec = math.Solve(method, 0, 1e-3, x0=initial_guess, max_iterations=100)
        solution = field.solve_linear(math.jit_compile_linear(field.laplace), rhs, solve_spec)
        math.assert_close(solution.values, expected, abs_tolerance=1e-3)
def test_solve_linear_matrix_dirichlet(self):
    """Compare plain and jit-compiled `field.laplace` solves on grids using `extrapolation.ONE`.

    On every backend, both solutions must agree with each other and with
    the expected values [-0.5, -1, -0.5] within an absolute tolerance of
    1e-3.
    """
    for backend in BACKENDS:
        with backend:
            rhs = CenteredGrid(1, extrapolation.ONE, x=3)
            initial_guess = CenteredGrid(0, extrapolation.ONE, x=3)
            solve_spec = math.Solve('CG', 0, 1e-3, x0=initial_guess, max_iterations=100)

            # Reference: the function is passed as-is.
            reference = field.solve_linear(field.laplace, rhs, solve_spec)
            # Candidate: the function is wrapped by jit_compile_linear first.
            compiled = field.solve_linear(
                math.jit_compile_linear(field.laplace), rhs, solve_spec)

            math.assert_close(
                reference.values,
                compiled.values,
                [-0.5, -1, -0.5],
                abs_tolerance=1e-3,
                msg=backend,
            )
def test_solve_linear_function_batched(self):
    """Solve for a right-hand side scaled by (1, 2) and compare against a wrapped expected tensor.

    Each solver method is run once directly and once under a
    `math.SolveTape`; for the taped run the number of recorded solves and
    the residual are asserted as well.
    """
    rhs = CenteredGrid(1, extrapolation.ZERO, x=3) * (1, 2)
    initial_guess = CenteredGrid(0, extrapolation.ZERO, x=3)
    tolerance = 1e-3

    for method in ('CG', 'CG-adaptive', 'auto'):
        solve = math.Solve(method, 0, 1e-3, x0=initial_guess, max_iterations=100)

        # Untaped solve.
        x = field.solve_linear(math.jit_compile_linear(field.laplace), rhs, solve)
        expected = math.wrap([[-1.5, -2, -1.5], [-3, -4, -3]],
                             channel('vector'), spatial('x'))
        math.assert_close(x.values, expected, abs_tolerance=tolerance)

        # Same solve again, this time recorded on a tape.
        with math.SolveTape() as solves:
            x = field.solve_linear(math.jit_compile_linear(field.laplace), rhs, solve)
        math.assert_close(x.values, expected, abs_tolerance=tolerance)
        assert len(solves) == 1
        # Positional and Solve-keyed lookups must return equal records.
        assert solves[0] == solves[solve]
        math.assert_close(solves[solve].residual.values, 0, abs_tolerance=tolerance)
def test_solve_linear_matrix(self):
    """Solve with the jit-compiled `field.laplace` on every backend and solver method.

    The solution values must match [-1.5, -2, -1.5] within an absolute
    tolerance of 1e-3; the backend is passed as the assertion message.
    """
    expected = [-1.5, -2, -1.5]
    for backend in BACKENDS:
        with backend:
            rhs = CenteredGrid(1, extrapolation.ZERO, x=3)
            initial_guess = CenteredGrid(0, extrapolation.ZERO, x=3)
            for method in ('CG', 'CG-adaptive', 'auto'):
                solve_spec = math.Solve(method, 0, 1e-3, x0=initial_guess,
                                        max_iterations=100)
                solution = field.solve_linear(
                    math.jit_compile_linear(field.laplace), rhs, solve_spec)
                math.assert_close(solution.values, expected,
                                  abs_tolerance=1e-3, msg=backend)
def test_jit_compile_linear(self):
    """Check that `math.jit_compile_linear` reproduces the direct result of several functions.

    A hand-written function built from negation, scaling, padding, slicing
    and summation, plus three `math.spatial_gradient` variants, are each
    evaluated directly and through `math.jit_compile_linear` on the same
    random input; the two results must be close.
    """
    math.GLOBAL_AXIS_ORDER.x_last()
    sample = math.random_normal(batch(batch=3) & spatial(x=4, y=3))  # , vector=2

    def linear_function(val):
        """Chain of negation, scaling, three pads with different extrapolations, slicing and a sum."""
        val = -val
        val *= 2
        periodic = math.pad(val, {'x': (2, 0), 'y': (0, 1)},
                            math.extrapolation.PERIODIC)
        # Add two differently shifted slices of the padded tensor.
        combined = periodic.x[:-2].y[1:] + periodic.x[2:].y[:-1]
        zero_padded = math.pad(combined, {'x': (0, 0), 'y': (0, 1)},
                               math.extrapolation.ZERO)
        boundary_padded = math.pad(zero_padded, {'x': (2, 2), 'y': (0, 1)},
                                   math.extrapolation.BOUNDARY)
        return math.sum([boundary_padded, boundary_padded], dim='0') - boundary_padded

    def x_gradient(difference, padding):
        """Return a function taking the first gradient component along 'x'."""
        return lambda val: math.spatial_gradient(
            val, difference=difference, padding=padding, dims='x').gradient[0]

    candidates = [
        linear_function,
        x_gradient('forward', math.extrapolation.ZERO),
        x_gradient('backward', math.extrapolation.PERIODIC),
        x_gradient('central', math.extrapolation.BOUNDARY),
    ]

    for f in candidates:
        direct_result = f(sample)
        compiled = math.jit_compile_linear(f)
        jit_result = compiled(sample)
        math.assert_close(direct_result, jit_result)
def solve(y, method):
    """Run a linear solve of the jit-compiled `field.laplace` for right-hand side `y`.

    `x0` is not a parameter; it is read from the surrounding scope, which
    is not visible here.

    Args:
        y: Right-hand side passed to `field.solve_linear`.
        method: Solver method name forwarded to `math.Solve`.

    Returns:
        The result of `field.solve_linear`.
    """
    solve_spec = math.Solve(method, 0, 1e-3, x0=x0, max_iterations=100)
    compiled_laplace = math.jit_compile_linear(field.laplace)
    return field.solve_linear(compiled_laplace, y, solve_spec)