示例#1
0
 def test_linear_solve_matrix_tape(self):
     """Solve a vector-batched 1D Poisson problem and inspect the SolveTape.

     For every solver method the system is solved twice with a
     matrix-compiled Laplacian: first under a plain ``SolveTape`` (checks the
     solution, tape indexing by position and by ``Solve``, residual and
     iteration count), then with ``record_trajectories=True`` (checks the
     trajectory length and the final residual).
     """
     expected = [[-1.5, -2, -1.5], [-3, -4, -3]]
     rhs = CenteredGrid(1, extrapolation.ZERO, x=3) * (1, 2)
     guess = CenteredGrid(0, extrapolation.ZERO, x=3)
     for method in ('CG', 'CG-adaptive', 'auto'):
         solve = math.Solve(method, 0, 1e-3, x0=guess, max_iterations=100)

         # Pass 1: default tape, no trajectories.
         with math.SolveTape() as tape:
             result = field.solve_linear(
                 math.jit_compile_linear(field.laplace), rhs, solve)
         math.assert_close(result.values, expected, abs_tolerance=1e-3)
         assert len(tape) == 1
         assert tape[0] == tape[solve]
         record = tape[solve]
         math.assert_close(record.residual.values, 0, abs_tolerance=1e-3)
         # Either 2 iterations, or -1 (accepted as an alternative value).
         iterations = record.iterations
         assert math.close(iterations, 2) or math.close(iterations, -1)

         # Pass 2: same solve, now recording the full trajectory.
         with math.SolveTape(record_trajectories=True) as tape:
             result = field.solve_linear(
                 math.jit_compile_linear(field.laplace), rhs, solve)
         math.assert_close(result.values, expected, abs_tolerance=1e-3)
         record = tape[solve]
         assert record.x.trajectory.size == 3
         math.assert_close(record.residual.trajectory[-1].values,
                           0,
                           abs_tolerance=1e-3)
示例#2
0
 def solve(y, method):
     """Solve ``laplace(x) = y`` with a matrix-compiled Laplacian under a SolveTape.

     ``backend``, ``x0``, ``SolveTape``, ``math`` and ``field`` are taken from
     the enclosing scope (not visible here). The message is printed on every
     call — presumably to show when a jit wrapper re-traces this function.
     """
     print(f"Tracing {method} with {backend}...")
     params = math.Solve(method, 0, 1e-3, x0=x0, max_iterations=100)
     # The tape itself is never inspected; the solve simply runs under it.
     with SolveTape():
         solution = field.solve_linear(
             math.jit_compile_linear(field.laplace), y, params)
     return solution
示例#3
0
 def test_sparse_matrix(self):
     """Trace ``math.laplace`` into a sparse matrix in each format on every backend.

     Checks that the requested format is reported back as ``indexing_type``
     and that a 5-cell input yields a 5x5 matrix.
     """
     formats = ('csr', 'csc', 'coo')
     for backend in BACKENDS:
         with backend:
             for fmt in formats:
                 linear_laplace = math.jit_compile_linear(math.laplace)
                 probe = math.zeros(spatial(x=5))
                 matrix = linear_laplace.sparse_matrix(probe, format=fmt)
                 self.assertEqual(fmt, matrix.indexing_type)
                 self.assertEqual((5, 5), matrix.shape)
示例#4
0
 def test_solve_diverge(self):
     """Check that an unsolvable Laplace system raises ``Diverged``.

     The solve is attempted twice: once on its own, and once inside a
     ``SolveTape`` that records trajectories. Both attempts must raise.

     NOTE(review): the system is presumably unsolvable because
     ``math.laplace``'s default padding makes the operator singular for this
     constant right-hand side — confirm against the ``math.laplace`` docs.
     """
     y = math.ones(spatial(x=2)) * (1, 2)
     x0 = math.zeros(spatial(x=2))
     for method in ['CG']:
         solve = Solve(method, 0, 1e-3, x0=x0, max_iterations=100)
         # assertRaises replaces try / `assert False`: the bare assert is
         # stripped under `python -O`, which would let a non-raising solve
         # pass silently.
         with self.assertRaises(Diverged):
             field.solve_linear(math.jit_compile_linear(math.laplace), y,
                                solve)
         with math.SolveTape(record_trajectories=True):
             with self.assertRaises(Diverged):
                 field.solve_linear(math.jit_compile_linear(math.laplace),
                                    y, solve)  # impossible
示例#5
0
 def test_linear_solve_matrix_batched(self):  # TODO also test batched matrix
     """Solve a vector-batched 1D Poisson problem with a matrix-compiled Laplacian.

     The right-hand side carries two components (1 and 2), so the expected
     solution has one row per component.
     """
     rhs = CenteredGrid(1, extrapolation.ZERO, x=3) * (1, 2)
     guess = CenteredGrid(0, extrapolation.ZERO, x=3)
     for method in ('CG', 'CG-adaptive', 'auto'):
         solve = math.Solve(method, 0, 1e-3, x0=guess, max_iterations=100)
         operator = math.jit_compile_linear(field.laplace)
         solution = field.solve_linear(operator, rhs, solve)
         math.assert_close(solution.values,
                           [[-1.5, -2, -1.5], [-3, -4, -3]],
                           abs_tolerance=1e-3)
示例#6
0
 def test_solve_linear_matrix_dirichlet(self):
     """Compare function-based and matrix-compiled CG solves with boundary value 1.

     Runs on every backend. The plain ``field.laplace`` solve, the jit-compiled
     solve and the analytic answer are passed together to ``assert_close``.
     """
     for backend in BACKENDS:
         with backend:
             rhs = CenteredGrid(1, extrapolation.ONE, x=3)
             guess = CenteredGrid(0, extrapolation.ONE, x=3)
             cg = math.Solve('CG', 0, 1e-3, x0=guess, max_iterations=100)
             reference = field.solve_linear(field.laplace, rhs, cg)
             compiled = field.solve_linear(
                 math.jit_compile_linear(field.laplace), rhs, cg)
             math.assert_close(reference.values,
                               compiled.values,
                               [-0.5, -1, -0.5],
                               abs_tolerance=1e-3,
                               msg=backend)
示例#7
0
 def test_solve_linear_function_batched(self):
     """Solve a vector-batched Poisson problem with and without a SolveTape.

     The expected solution is given explicitly as a ``vector`` x ``x``
     tensor. The taped run also checks the tape's size, its indexing and the
     recorded residual.
     """
     expected = math.wrap([[-1.5, -2, -1.5], [-3, -4, -3]], channel('vector'),
                          spatial('x'))
     rhs = CenteredGrid(1, extrapolation.ZERO, x=3) * (1, 2)
     guess = CenteredGrid(0, extrapolation.ZERO, x=3)
     for method in ('CG', 'CG-adaptive', 'auto'):
         solve = math.Solve(method, 0, 1e-3, x0=guess, max_iterations=100)

         untaped = field.solve_linear(math.jit_compile_linear(field.laplace),
                                      rhs, solve)
         math.assert_close(untaped.values, expected, abs_tolerance=1e-3)

         with math.SolveTape() as tape:
             taped = field.solve_linear(
                 math.jit_compile_linear(field.laplace), rhs, solve)
         math.assert_close(taped.values, expected, abs_tolerance=1e-3)
         assert len(tape) == 1
         assert tape[0] == tape[solve]
         math.assert_close(tape[solve].residual.values, 0, abs_tolerance=1e-3)
示例#8
0
 def test_solve_linear_matrix(self):
     """Solve a 3-cell Poisson problem (zero boundary) on every backend and method."""
     methods = ('CG', 'CG-adaptive', 'auto')
     for backend in BACKENDS:
         with backend:
             rhs = CenteredGrid(1, extrapolation.ZERO, x=3)
             guess = CenteredGrid(0, extrapolation.ZERO, x=3)
             for method in methods:
                 solve = math.Solve(method, 0, 1e-3, x0=guess,
                                    max_iterations=100)
                 solution = field.solve_linear(
                     math.jit_compile_linear(field.laplace), rhs, solve)
                 math.assert_close(solution.values, [-1.5, -2, -1.5],
                                   abs_tolerance=1e-3,
                                   msg=backend)
示例#9
0
    def test_jit_compile_linear(self):
        """Check that ``jit_compile_linear`` reproduces several linear functions.

        Each candidate is applied directly and through its jit-compiled
        linear version, and the two results are compared.
        """
        # NOTE(review): switches the global axis order and never restores it;
        # confirm that later tests do not depend on the previous setting.
        math.GLOBAL_AXIS_ORDER.x_last()
        sample = math.random_normal(batch(batch=3) & spatial(x=4, y=3))

        def linear_function(t):
            # Sign flip and scaling.
            t = -t
            t *= 2
            # Periodic pad, then sum two shifted views of the padded tensor.
            t = math.pad(t, {'x': (2, 0), 'y': (0, 1)},
                         math.extrapolation.PERIODIC)
            t = t.x[:-2].y[1:] + t.x[2:].y[:-1]
            # Two more pads with different extrapolations.
            t = math.pad(t, {'x': (0, 0), 'y': (0, 1)},
                         math.extrapolation.ZERO)
            t = math.pad(t, {'x': (2, 2), 'y': (0, 1)},
                         math.extrapolation.BOUNDARY)
            # Presumably sums the two list entries (2*t) and subtracts t,
            # exercising list-sum and subtraction in the traced operator.
            return math.sum([t, t], dim='0') - t

        def x_derivative(difference, padding):
            # First gradient component along x for one stencil/padding choice.
            return lambda t: math.spatial_gradient(
                t, difference=difference, padding=padding,
                dims='x').gradient[0]

        functions = [
            linear_function,
            x_derivative('forward', math.extrapolation.ZERO),
            x_derivative('backward', math.extrapolation.PERIODIC),
            x_derivative('central', math.extrapolation.BOUNDARY),
        ]
        for fn in functions:
            expected = fn(sample)
            traced = math.jit_compile_linear(fn)
            math.assert_close(expected, traced(sample))
示例#10
0
 def solve(y, method):
     """Solve ``laplace(x) = y`` with a matrix-compiled Laplacian and return the solution.

     ``x0``, ``math`` and ``field`` are taken from the enclosing scope (not
     visible here).
     """
     settings = math.Solve(method, 0, 1e-3, x0=x0, max_iterations=100)
     operator = math.jit_compile_linear(field.laplace)
     return field.solve_linear(operator, y, settings)