def test_numeric_constant(self):
    """
    A plain numeric coefficient (0.3) on one derivative term must not stop
    collection: after collect_derivatives the equation is expected to contain
    exactly three Derivative nodes.
    """
    grid = Grid(shape=(10, 10))
    field = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)

    target = field.forward
    rhs = field.dx.dx + 0.3 * field.dy.dx
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert len(collected.find(Derivative)) == 3
def test_nocollection_if_unworthy(self):
    """
    For ``(0.4 + dt) * (u.dx + u.dy)`` collect_derivatives is expected to
    return an equation equal to its input (the test name indicates the
    rewrite is judged not worthwhile here).
    """
    grid = Grid(shape=(10, 10))
    dt = grid.time_dim.spacing
    field = TimeFunction(name="u", grid=grid)

    target = field.forward
    rhs = (0.4 + dt) * (field.dx + field.dy)
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert original == collected
def test_symbolic_constant(self):
    """
    A symbolic coefficient (``dt**0.2``) on one derivative term must not stop
    collection: the collected equation is expected to contain exactly three
    Derivative nodes.
    """
    grid = Grid(shape=(10, 10))
    dt = grid.time_dim.spacing
    field = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)

    target = field.forward
    rhs = field.dx.dx + dt**0.2 * field.dy.dx
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert len(collected.find(Derivative)) == 3
def test_nocollection_staggered(self):
    """
    Derivatives of a node-centred field (``u``) and a field staggered in ``x``
    (``v``) are expected to be left uncollected: the equation comes back
    unchanged.
    """
    grid = Grid(shape=(10, 10))
    x, _ = grid.dimensions
    u_node = TimeFunction(name="u", grid=grid)
    v_stag = TimeFunction(name="v", grid=grid, staggered=x)

    target = u_node.forward
    rhs = u_node.dx + v_stag.dx
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert original == collected
def test_nocollection_subdims(self):
    """
    When one derivative term is scaled by a Function indexed on the grid's
    interior subdimensions, collect_derivatives is expected to leave the
    equation unchanged.
    """
    grid = Grid(shape=(10, 10))
    xi, yi = grid.interior.dimensions
    u = TimeFunction(name="u", grid=grid)
    v = TimeFunction(name="v", grid=grid)
    f = Function(name='f', grid=grid)

    target = u.forward
    rhs = u.dx + 0.2 * f[xi, yi] * v.dx
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert original == collected
def test_symbolic_constant_times_add(self):
    """
    ``u.laplace + dt**0.2 * u.biharmonic(1 / f)`` starts as a three-term sum;
    after collection it is expected to become a two-term sum in which every
    term is a Derivative.
    """
    grid = Grid(shape=(10, 10))
    dt = grid.time_dim.spacing
    u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
    f = Function(name='f', grid=grid)

    target = u.forward
    rhs = u.laplace + dt**0.2 * u.biharmonic(1 / f)
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert len(original.rhs.args) == 3
    assert len(collected.rhs.args) == 2
    for term in collected.rhs.args:
        assert isinstance(term, Derivative)
def test_pull_and_collect_nested_v3(self):
    """
    A derivative sum nested two multiplications deep
    (``a * (hx + dt * (u.dx + v.dx))``) is expected to have ``dt`` pulled into
    a single collected ``.dx`` while the outer structure is preserved.
    """
    grid = Grid(shape=(10, 10))
    dt = grid.time_dim.spacing
    hx, _ = grid.spacing_symbols
    a = Function(name="a", grid=grid, space_order=2)
    u = TimeFunction(name="u", grid=grid, space_order=2)
    v = TimeFunction(name="v", grid=grid, space_order=2)

    target = u.forward
    rhs = 0.4 + a * (hx + dt * (u.dx + v.dx))
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert original != collected
    expected = 0.4 + a * (hx + (dt * u + dt * v).dx)
    assert collected.rhs == expected
def test_pull_and_collect_nested_v2(self):
    """
    A compound spacing-dependent coefficient on ``u.dx``, itself scaled by
    ``hx``, is expected to be pulled inside a single ``.dx`` shared with
    ``v.dx``, leaving only ``0.3 * hx`` outside.
    """
    grid = Grid(shape=(10, 10))
    dt = grid.time_dim.spacing
    hx, hy = grid.spacing_symbols
    u = TimeFunction(name="u", grid=grid, space_order=2)
    v = TimeFunction(name="v", grid=grid, space_order=2)

    target = u.forward
    # Coefficient appearing both in the input and in the expected result
    coeff = 0.4 + dt * (hy + 1. + hx * hy)
    rhs = (coeff * u.dx + 0.3) * hx + v.dx
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert original != collected
    expected = 0.3 * hx + (hx * coeff * u + v).dx
    assert collected.rhs == expected
def test_pull_and_collect(self):
    """
    ``((0.4 + dt) * u.dx + 0.3) * hx + v.dx`` is expected to collapse into
    two terms: the derivative-free ``0.3 * hx`` and one ``.dx`` of
    ``hx * (dt + 0.4) * u + v``.
    """
    grid = Grid(shape=(10, 10))
    dt = grid.time_dim.spacing
    hx, _ = grid.spacing_symbols
    u = TimeFunction(name="u", grid=grid)
    v = TimeFunction(name="v", grid=grid)

    target = u.forward
    rhs = ((0.4 + dt) * u.dx + 0.3) * hx + v.dx
    original = Eq(target, rhs)

    collected = collect_derivatives.func([original])[0]

    assert original != collected
    terms = collected.rhs.args
    assert len(terms) == 2
    plain_term, derivative_term = terms
    assert diff2sympy(plain_term) == 0.3 * hx
    assert derivative_term == (hx * (dt + 0.4) * u + v).dx
def test_solve(self):
    """
    Derivatives returned by ``solve()`` stay unevaluated until after the
    Operator's collect_derivatives pass, so they can still be collected.

    The solved right-hand side starts with five Derivative nodes. After
    collection it is expected to hold four, three of which sit inside its
    third argument (the factorized part).
    """
    grid = Grid(shape=(10, 10))
    u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
    pde = u.dt2 - (u.dx.dx + u.dy.dy) - u.dx.dy

    target = u.forward
    solved = solve(pde, u.forward)
    original = Eq(target, solved)

    collected = collect_derivatives.func([original])[0]

    assert len(original.rhs.find(Derivative)) == 5
    collected_rhs = collected.rhs
    assert len(collected_rhs.find(Derivative)) == 4
    # Check factorization: the third argument gathers three derivatives
    assert len(collected_rhs.args[2].find(Derivative)) == 3