Пример #1
0
    def test_numeric_constant(self):
        """
        Run collect_derivatives on ``u.dx.dx + 0.3 * u.dy.dx``, where one term
        carries a numeric coefficient. Check that the processed equation
        contains exactly three Derivative sub-expressions.
        """
        mesh = Grid(shape=(10, 10))
        field = TimeFunction(name="u", grid=mesh, space_order=4, time_order=2)

        lhs = field.forward
        rhs = field.dx.dx + 0.3 * field.dy.dx
        processed = collect_derivatives.func([Eq(lhs, rhs)])
        collected = processed[0]

        n_derivs = len(collected.find(Derivative))
        assert n_derivs == 3
Пример #2
0
    def test_nocollection_if_unworthy(self):
        """
        For ``(0.4 + dt) * (u.dx + u.dy)``, collect_derivatives must hand back
        an equation that compares equal to its input, i.e. nothing is
        rewritten.
        """
        mesh = Grid(shape=(10, 10))
        timestep = mesh.time_dim.spacing

        field = TimeFunction(name="u", grid=mesh)

        original = Eq(field.forward, (0.4 + timestep) * (field.dx + field.dy))
        processed = collect_derivatives.func([original])
        collected = processed[0]

        assert original == collected
Пример #3
0
    def test_symbolic_constant(self):
        """
        Same check as the numeric-coefficient case, but with the symbolic
        coefficient ``dt**0.2``. After collect_derivatives, exactly three
        Derivative sub-expressions must remain.
        """
        mesh = Grid(shape=(10, 10))
        timestep = mesh.time_dim.spacing

        field = TimeFunction(name="u", grid=mesh, space_order=4, time_order=2)

        lhs = field.forward
        rhs = field.dx.dx + timestep**0.2 * field.dy.dx
        processed = collect_derivatives.func([Eq(lhs, rhs)])
        collected = processed[0]

        n_derivs = len(collected.find(Derivative))
        assert n_derivs == 3
Пример #4
0
    def test_nocollection_staggered(self):
        """
        Derivatives of a regular function ``u`` and a function ``v`` staggered
        in ``x`` are summed. collect_derivatives must return an equation equal
        to the input, i.e. the two terms are not merged.
        """
        grid = Grid(shape=(10, 10))
        # Only the first dimension is needed, as the staggering direction.
        x, _ = grid.dimensions

        u = TimeFunction(name="u", grid=grid)
        v = TimeFunction(name="v", grid=grid, staggered=x)

        eq = Eq(u.forward, u.dx + v.dx)
        leq = collect_derivatives.func([eq])[0]

        assert eq == leq
Пример #5
0
    def test_nocollection_subdims(self):
        """
        One derivative term is scaled by ``f`` indexed on the grid's interior
        dimensions. collect_derivatives must return an equation equal to the
        input.
        """
        mesh = Grid(shape=(10, 10))
        ix, iy = mesh.interior.dimensions

        first = TimeFunction(name="u", grid=mesh)
        second = TimeFunction(name="v", grid=mesh)
        weight = Function(name='f', grid=mesh)

        original = Eq(first.forward,
                      first.dx + 0.2 * weight[ix, iy] * second.dx)
        processed = collect_derivatives.func([original])
        collected = processed[0]

        assert original == collected
Пример #6
0
    def test_symbolic_constant_times_add(self):
        """
        ``u.laplace + dt**0.2 * u.biharmonic(1 / f)`` starts with a
        three-argument right-hand side. After collect_derivatives, the rhs must
        have exactly two arguments, each of them a Derivative.
        """
        mesh = Grid(shape=(10, 10))
        timestep = mesh.time_dim.spacing

        field = TimeFunction(name="u", grid=mesh, space_order=4, time_order=2)
        coeff = Function(name='f', grid=mesh)

        lhs = field.forward
        rhs = field.laplace + timestep**0.2 * field.biharmonic(1 / coeff)
        original = Eq(lhs, rhs)
        collected = collect_derivatives.func([original])[0]

        before, after = original.rhs.args, collected.rhs.args
        assert len(before) == 3
        assert len(after) == 2
        assert all(isinstance(term, Derivative) for term in after)
Пример #7
0
    def test_pull_and_collect_nested_v3(self):
        """
        The factor ``dt`` in the nested ``a * (hx + dt * (u.dx + v.dx))`` must
        be pulled inside a single collected derivative, yielding
        ``0.4 + a * (hx + (dt * u + dt * v).dx)``.
        """
        grid = Grid(shape=(10, 10))
        dt = grid.time_dim.spacing
        # Only the x-spacing symbol appears in this expression.
        hx, _ = grid.spacing_symbols

        a = Function(name="a", grid=grid, space_order=2)
        u = TimeFunction(name="u", grid=grid, space_order=2)
        v = TimeFunction(name="v", grid=grid, space_order=2)

        eq = Eq(u.forward, 0.4 + a * (hx + dt * (u.dx + v.dx)))
        leq = collect_derivatives.func([eq])[0]

        assert eq != leq
        assert leq.rhs == 0.4 + a * (hx + (dt * u + dt * v).dx)
Пример #8
0
    def test_pull_and_collect_nested_v2(self):
        """
        The coefficient ``hx * (0.4 + dt * (hy + 1. + hx * hy))`` of ``u.dx`` is
        pulled inside one derivative shared with ``v.dx``. The constant term
        ``0.3 * hx`` is left outside.
        """
        mesh = Grid(shape=(10, 10))
        timestep = mesh.time_dim.spacing
        hx, hy = mesh.spacing_symbols

        first = TimeFunction(name="u", grid=mesh, space_order=2)
        second = TimeFunction(name="v", grid=mesh, space_order=2)

        lhs = first.forward
        coeff = 0.4 + timestep * (hy + 1. + hx * hy)
        original = Eq(lhs, (coeff * first.dx + 0.3) * hx + second.dx)
        collected = collect_derivatives.func([original])[0]

        assert original != collected
        expected = 0.3 * hx + (hx * coeff * first + second).dx
        assert collected.rhs == expected
Пример #9
0
    def test_pull_and_collect(self):
        """
        Expanding ``((0.4 + dt) * u.dx + 0.3) * hx + v.dx`` and collecting must
        give a two-term rhs: first ``0.3 * hx`` (compared after diff2sympy),
        then the single derivative ``(hx * (dt + 0.4) * u + v).dx``.
        """
        mesh = Grid(shape=(10, 10))
        timestep = mesh.time_dim.spacing
        hx, _ = mesh.spacing_symbols

        first = TimeFunction(name="u", grid=mesh)
        second = TimeFunction(name="v", grid=mesh)

        lhs = first.forward
        rhs = ((0.4 + timestep) * first.dx + 0.3) * hx + second.dx
        original = Eq(lhs, rhs)
        collected = collect_derivatives.func([original])[0]

        assert original != collected
        terms = collected.rhs.args
        assert len(terms) == 2
        constant_part, deriv_part = terms[0], terms[1]
        assert diff2sympy(constant_part) == 0.3 * hx
        assert deriv_part == (hx * (timestep + 0.4) * first + second).dx
Пример #10
0
    def test_solve(self):
        """
        Because the Derivatives produced by solve() stay unevaluated until the
        Operator runs collect_derivatives, they can still be collected.
        The solved rhs has five Derivative sub-expressions. After collection it
        has four, and three of those sit inside the third rhs argument (the
        factorised term).
        """
        mesh = Grid(shape=(10, 10))

        field = TimeFunction(name="u", grid=mesh, space_order=4, time_order=2)

        pde = field.dt2 - (field.dx.dx + field.dy.dy) - field.dx.dy
        lhs = field.forward
        original = Eq(lhs, solve(pde, field.forward))
        collected = collect_derivatives.func([original])[0]

        assert len(original.rhs.find(Derivative)) == 5
        assert len(collected.rhs.find(Derivative)) == 4
        # Check factorization
        factored_term = collected.rhs.args[2]
        assert len(factored_term.find(Derivative)) == 3