def test_match_modulo_identity():
    """Check that ``match`` can bind a free variable to an identity element.

    Matching ``c*a + b*a`` against ``c*a + a`` must bind ``b`` to 1.
    Matching ``(c + a)*(b + a)`` against ``(c + a)*a`` must bind ``b`` to 0.
    """
    from dagrt.expression import match

    a, b, c = declare("a", "b", "c")

    # b*a vs. a: b should be the multiplicative identity.
    times_one = match(c * a + b * a, c * a + a, ["b"])
    assert times_one["b"] == 1

    # (b + a) vs. a: b should be the additive identity.
    plus_zero = match((c + a) * (b + a), (c + a) * a, ["b"])
    assert plus_zero["b"] == 0
def test_match_with_pre_match():
    """Check that a ``pre_match`` binding is kept and the rest is inferred.

    ``a`` is pinned to ``c`` up front, so matching ``a + b`` against
    ``c + d`` must leave ``b`` bound to ``d``.
    """
    from dagrt.expression import match

    a, b, c, d = declare("a", "b", "c", "d")
    template, target = a + b, c + d

    bindings = match(template, target, ["a", "b"], pre_match={"a": "c"})

    assert bindings["a"] == c
    assert bindings["b"] == d
def test_match_strings():
    """Check that ``match`` accepts string templates and expressions.

    Matching the string ``"a+b*a"`` against itself, with no free-variable
    list passed, must yield exactly two bindings. Each variable is mapped
    to itself.
    """
    from pymbolic import var
    from dagrt.expression import match

    expression = "a+b*a"
    bindings = match(expression, expression)

    assert len(bindings) == 2
    for name in ("a", "b"):
        assert bindings[name] == var(name)
def solver_hook(solve_expr, unknown, solver_id, guess): from dagrt.expression import match, substitute pieces = match( "k - <func>rhs(time, y + dt * (c0 + c1 * k))", solve_expr, pre_match={"k": unknown}) return substitute("-10 * (dt * c0 + y) / (10 * dt * c1 + 1)", pieces)
def solver_hook(solve_expr, solve_var, solver_id, guess):
    """Rewrite an implicit solve as a call to an external ``solver``.

    *solve_expr* is matched against ``unk - rhs(t=t, y=sub_y + coeff*unk)``,
    with ``unk`` pinned to *solve_var*. The bindings, plus *guess*, are
    substituted into ``solver(t, sub_y, coeff, guess)``. *solver_id* is
    unused.
    """
    from dagrt.expression import match, substitute

    residual_pattern = "unk - <func>rhs(t=t, y=sub_y + coeff*unk)"
    bindings = match(residual_pattern, solve_expr,
                     pre_match={"unk": solve_var})

    bindings["guess"] = guess
    return substitute("<func>solver(t, sub_y, coeff, guess)", bindings)
def solver_hook(solve_expr, solve_var, solver_id, guess):
    """Rewrite an implicit solve as a call to ``solver(sub_y, coeff)``.

    *solve_expr* is matched against
    ``unk + (-1)*impl_y(y=sub_y + coeff*unk, t=t)``, with ``unk`` pinned
    to *solve_var*. *guess* is added to the bindings, but the replacement
    template does not reference it. *solver_id* is unused.
    """
    from dagrt.expression import match, substitute

    residual_pattern = "unk + (-1)*<func>impl_y(y=sub_y + coeff*unk, t=t)"
    bindings = match(residual_pattern, solve_expr,
                     pre_match={"unk": solve_var})

    bindings["guess"] = guess
    return substitute("<func>solver(sub_y, coeff)", bindings)
def am_solver_hook(solve_expr, solve_var, solver_id, guess):
    """Rewrite a fast/slow implicit solve as a call to ``solver``.

    *solve_expr* is matched against
    ``unk + (-1)*f(t=t, fast=sub_fast + coeff*unk, slow=sub_slow)``, with
    ``unk`` pinned to *solve_var*. The pieces are substituted into
    ``solver(sub_fast, sub_slow, coeff, t)``. *guess* is added to the
    bindings but not referenced by that template. *solver_id* is unused.
    """
    from dagrt.expression import match, substitute

    residual_pattern = ("unk + (-1)*<func>f(t=t, fast=sub_fast + coeff*unk, "
                        "slow=sub_slow)")
    replacement = ("<func>solver(sub_fast, sub_slow, "
                   "coeff, t)")

    bindings = match(residual_pattern, solve_expr,
                     pre_match={"unk": solve_var})
    bindings["guess"] = guess
    return substitute(replacement, bindings)
def solver_hook(solve_expr, solve_var, solver_id, guess):
    """Rewrite an implicit solve that involves a state variable.

    *solve_expr* is matched against
    ``unk - rhs(t=t, y=<state>y + sub_y + coeff*unk)``. ``<state>y`` is
    declared as a bound variable name, and ``unk`` is pinned to
    *solve_var*. The pieces are substituted into
    ``solver(t, <state>y, sub_y, coeff)``. *guess* is added to the bindings
    but not referenced by that template. *solver_id* is unused.
    """
    from dagrt.expression import match, substitute

    residual_pattern = "unk - <func>rhs(t=t, y=<state>y + sub_y + coeff*unk)"
    bindings = match(residual_pattern, solve_expr,
                     bound_variable_names=["<state>y"],
                     pre_match={"unk": solve_var})

    bindings["guess"] = guess
    return substitute("<func>solver(t, <state>y, sub_y, coeff)", bindings)
def test_match():
    """Check that a residual ``y - h*f(t, y)`` matches a reordered form.

    The target ``-hh*f(tt, yy) + yy`` differs only in term order and
    naming. ``t``, ``h`` and ``y`` must bind to ``tt``, ``hh`` and ``yy``.
    """
    from dagrt.expression import match

    f, y, h, t, yy, hh, tt = declare("f", "y", "h", "t", "yy", "hh", "tt")
    template = y - h * f(t, y)
    target = -hh * f(tt, yy) + yy

    bindings = match(template, target, ["t", "h", "y"])

    expected = {"h": hh, "t": tt, "y": yy}
    assert len(bindings) == len(expected)
    for name, value in expected.items():
        assert bindings[name] == value
def test_match_functions():
    """Check matching through a function call with keyword arguments.

    Every name in the template (including the function ``f``) is free. Each
    one must bind to its doubled-letter counterpart in the target.
    """
    from dagrt.expression import match

    lhsvars = ["f", "u", "s", "c", "t"]
    f, u, s, c, t = declare(*lhsvars)
    ff, uu, ss, cc, tt = declare("ff", "uu", "ss", "cc", "tt")

    template = u - f(t=t, y=s + c * u)
    target = uu - ff(t=tt, y=ss + cc * uu)
    bindings = match(template, target, lhsvars)

    assert len(bindings) == len(lhsvars)
    for name, expected in zip(lhsvars, (ff, uu, ss, cc, tt)):
        assert bindings[name] == expected
def test_match_with_pre_match_invalid_arg():
    """Check that ``match`` rejects an invalid ``pre_match`` binding.

    ``b`` is pre-matched even though only ``a`` is listed as free. The
    call is expected to raise ``ValueError``.
    """
    from dagrt.expression import match

    a, b, c, d = declare("a", "b", "c", "d")
    template, target = a + b, c + d

    with pytest.raises(ValueError):
        match(template, target, ["a"], pre_match={"b": "c"})
def solver_hook(expr, var, solver_id, guess):
    """Rewrite an implicit solve as ``solver(t, h, y, guess)``.

    *expr* is matched against ``unk - y - h*f(t=t, y=unk)``, with ``unk``
    pinned to *var*. The bindings, plus *guess*, are substituted into the
    solver call. *solver_id* is unused.
    """
    from dagrt.expression import match, substitute

    residual_pattern = "unk-y-h*<func>f(t=t,y=unk)"
    bindings = match(residual_pattern, expr, pre_match={"unk": var})

    bindings["guess"] = guess
    return substitute("<func>solver(t,h,y,guess)", bindings)