def test_unify():
    """Unify ``Basic(1, 2, 3)`` with a pattern made of three variable slots."""
    a, b, c = map(Symbol, 'abc')
    wild = (a, b, c)
    target = Basic(1, 2, 3)
    template = Basic(*wild)
    want = [{a: 1, b: 2, c: 3}]
    # Positional and keyword forms of ``variables`` must give the same result.
    assert list(unify(target, template, {}, wild)) == want
    assert list(unify(target, template, variables=wild)) == want
def test_matrix():
    """Unify MatrixSymbols, treating the name ``Str('X')`` and size ``n`` as variables."""
    from sympy.matrices.expressions.matexpr import MatrixSymbol
    X = MatrixSymbol('X', n, n)
    Y = MatrixSymbol('Y', 2, 2)
    Z = MatrixSymbol('Z', 2, 3)
    wilds = [n, Str('X')]
    # A square 2x2 target binds both the name and the symbolic size.
    assert list(unify(X, Y, {}, variables=wilds)) == [{Str('X'): Str('Y'), n: 2}]
    # A 2x3 target cannot match an n-by-n matrix.
    assert list(unify(X, Z, {}, variables=wilds)) == []
def test_matrix():
    """Unify MatrixSymbols, treating the plain-string name ``'X'`` and ``n`` as variables."""
    from sympy import MatrixSymbol
    X = MatrixSymbol('X', n, n)
    Y = MatrixSymbol('Y', 2, 2)
    Z = MatrixSymbol('Z', 2, 3)
    wilds = [n, 'X']
    # Square target: the name and the size are both bound.
    assert list(unify(X, Y, {}, variables=wilds)) == [{'X': 'Y', n: 2}]
    # Non-square target: no match.
    assert list(unify(X, Z, {}, variables=wilds)) == []
def test_matrix():
    """Unify a patternified n-by-n MatrixSymbol against concrete matrices."""
    from sympy import MatrixSymbol
    X = MatrixSymbol('X', n, n)
    Y = MatrixSymbol('Y', 2, 2)
    Z = MatrixSymbol('Z', 2, 3)
    # Mark both the name 'X' and the size n as pattern variables.
    square = patternify(X, 'X', n)
    assert list(unify(square, Y, {})) == [{'X': 'Y', n: 2}]
    # A 2x3 matrix cannot satisfy the n-by-n pattern.
    assert list(unify(square, Z, {})) == []
def test_matrix():
    """Unify MatrixSymbols, treating ``Symbol("X")`` and ``n`` as variables."""
    from sympy import MatrixSymbol
    X = MatrixSymbol("X", n, n)
    Y = MatrixSymbol("Y", 2, 2)
    Z = MatrixSymbol("Z", 2, 3)
    wilds = [n, Symbol("X")]
    expected_square = [{Symbol("X"): Symbol("Y"), n: 2}]
    assert list(unify(X, Y, {}, variables=wilds)) == expected_square
    # Mismatched (2x3) shape: nothing unifies.
    assert list(unify(X, Z, {}, variables=wilds)) == []
def test_unify_commutative():
    """A flat, unevaluated Add unifies with every permutation of its operands."""
    a, b, c = map(Symbol, 'abc')
    expr = Add(1, 2, 3, evaluate=False)
    pattern = Add(a, b, c, evaluate=False)
    result = tuple(unify(expr, pattern, {}, (a, b, c)))
    # All 3! orderings of (1, 2, 3) assigned to (a, b, c).
    orderings = ((1, 2, 3), (1, 3, 2), (2, 1, 3),
                 (2, 3, 1), (3, 1, 2), (3, 2, 1))
    expected = tuple({a: i, b: j, c: k} for i, j, k in orderings)
    assert iterdicteq(result, expected)
def test_FiniteSet_commutivity():
    """A two-slot FiniteSet pattern can absorb two of three elements into one slot."""
    from sympy import FiniteSet
    a, b, c, x, y = symbols('a,b,c,x,y')
    three = FiniteSet(a, b, c)
    pattern = patternify(FiniteSet(x, y), x, y)
    matches = tuple(unify(three, pattern))
    assert {x: FiniteSet(a, c), y: b} in matches
def test_hard_match():
    """``sin(x) + cos(x)**2`` matches ``sin(p) + cos(p)**2`` with one consistent p."""
    from sympy import sin, cos
    p, q = map(Symbol, "pq")
    target = sin(x) + cos(x) ** 2
    template = sin(p) + cos(p) ** 2
    # q is declared as a variable but does not occur in the pattern,
    # so the expected binding mentions only p.
    assert list(unify(target, template, {}, (p, q))) == [{p: x}]
def test_FiniteSet_commutivity():
    """A two-slot FiniteSet pattern can absorb two of three elements into one slot."""
    from sympy import FiniteSet
    a, b, c, x, y = symbols('a,b,c,x,y')
    source = FiniteSet(a, b, c)
    template = FiniteSet(x, y)
    matches = tuple(unify(source, template, variables=(x, y)))
    assert {x: FiniteSet(a, c), y: b} in matches
def test_commutative_in_commutative():
    """A commutative Add of commutative Muls yields at least one truthy match."""
    from sympy.abc import a, b, c, d
    from sympy import sin, cos
    wilds = (a, b, c, d)
    eq = sin(3) * sin(4) * sin(5) + 4 * cos(3) * cos(4)
    pat = a * cos(b) * cos(c) + d * sin(b) * sin(c)
    first = next(unify(eq, pat, variables=wilds))
    assert first
def test_FiniteSet_complex():
    """Match a patternified FiniteSet whose Basic element contains a variable slot."""
    from sympy import FiniteSet
    a, b, c, x, y, z = symbols('a,b,c,x,y,z')
    expr = FiniteSet(Basic(1, x), y, Basic(x, z))
    # In each expected match, a holds the two elements not consumed by Basic(x, b).
    expected = (
        {b: 1, a: FiniteSet(y, Basic(x, z))},
        {b: z, a: FiniteSet(y, Basic(1, x))},
    )
    pattern = patternify(FiniteSet(a, Basic(x, b)), a, b)
    assert iterdicteq(unify(expr, pattern), expected)
def test_FiniteSet_complex():
    """Match a FiniteSet pattern (variables a, b) whose Basic element holds a slot."""
    from sympy import FiniteSet
    a, b, c, x, y, z = symbols('a,b,c,x,y,z')
    expr = FiniteSet(Basic(1, x), y, Basic(x, z))
    pattern = FiniteSet(a, Basic(x, b))
    # In each expected match, a holds the two elements not consumed by Basic(x, b).
    expected = (
        {b: 1, a: FiniteSet(y, Basic(x, z))},
        {b: z, a: FiniteSet(y, Basic(1, x))},
    )
    assert iterdicteq(unify(expr, pattern, variables=(a, b)), expected)
def test_Union():
    """Two interval unions unify when the differing interval is the variable."""
    from sympy import Interval
    left = Interval(0, 1) + Interval(10, 11)
    right = Interval(0, 1) + Interval(12, 13)
    # The Interval(12, 13) object itself is the pattern variable.
    hole = (Interval(12, 13),)
    assert list(unify(left, right, variables=hole))
def rewrite_rl(expr, assumptions=True):
    """Yield ``target`` rewritten under every admissible unification of ``source`` with ``expr``.

    ``source``, ``target``, ``variables``, ``condition`` and ``assume`` are
    free names resolved outside this function (presumably the enclosing rule
    factory -- confirm). A match is skipped when ``condition`` is set and
    rejects the matched values, or when ``assume`` is set and ``ask`` does not
    confirm it (with the match substituted in) under ``assumptions``.
    Rewritten results that are ``Expr`` instances are passed through ``rebuild``.
    """
    for match in unify(source, expr, {}, variables=variables):
        if condition:
            # Variables absent from the match are passed through unchanged.
            values = [match.get(var, var) for var in variables]
            if not condition(*values):
                continue
        if assume and not ask(assume.xreplace(match), assumptions):
            continue
        rewritten = subs(match)(target)
        if isinstance(rewritten, Expr):
            rewritten = rebuild(rewritten)
        yield rewritten
def rewrite_rl(expr, assumptions=True):
    """Lazily produce rewrites of ``expr`` from each accepted match of ``source``.

    Depends on the free names ``source``, ``target``, ``variables``,
    ``condition`` and ``assume``, which are defined outside this function.
    """
    def admissible(match):
        # Reject the match if the optional predicate or the optional
        # assumption check fails; otherwise accept it.
        if condition and not condition(*[match.get(v, v) for v in variables]):
            return False
        if assume and not ask(assume.xreplace(match), assumptions):
            return False
        return True

    for match in unify(source, expr, {}, variables=variables):
        if not admissible(match):
            continue
        out = subs(match)(target)
        yield rebuild(out) if isinstance(out, Expr) else out
def test_unify_commutative():
    """Every operand permutation of an unevaluated Add is a valid binding."""
    a, b, c = map(Symbol, 'abc')
    slots = (a, b, c)
    expr = Add(1, 2, 3, evaluate=False)
    pattern = Add(*slots, evaluate=False)
    result = tuple(unify(expr, pattern, {}, slots))
    expected = (
        {a: 1, b: 2, c: 3},
        {a: 1, b: 3, c: 2},
        {a: 2, b: 1, c: 3},
        {a: 2, b: 3, c: 1},
        {a: 3, b: 1, c: 2},
        {a: 3, b: 2, c: 1},
    )
    assert iterdicteq(result, expected)
def test_unify_iter():
    """An associative-commutative Add splits three operands across two slots."""
    a, b, c = map(Symbol, 'abc')
    expr = Add(1, 2, 3, evaluate=False)
    pattern = Add(a, c, evaluate=False)
    assert is_associative(deconstruct(pattern))
    assert is_commutative(deconstruct(pattern))
    result = list(unify(expr, pattern, {}, (a, c)))

    def pair(u, v):
        return Add(u, v, evaluate=False)

    # a takes one operand and c the other two (in either order)...
    lone_a = [{a: s, c: pair(u, v)} for s, u, v in
              ((1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1))]
    # ...or a takes two operands and c the remaining one.
    lone_c = [{a: pair(u, v), c: s} for u, v, s in
              ((1, 2, 3), (2, 1, 3), (1, 3, 2), (3, 1, 2), (2, 3, 1), (3, 2, 1))]
    expected = lone_a + lone_c
    assert iterdicteq(result, expected)
def test_unify_iter():
    """Enumerate every way two AC slots can partition an unevaluated Add."""
    a, b, c = map(Symbol, 'abc')
    expr = Add(1, 2, 3, evaluate=False)
    pattern = Add(a, c, evaluate=False)
    structure = deconstruct(pattern)
    assert is_associative(structure)
    assert is_commutative(deconstruct(pattern))
    result = list(unify(expr, pattern, {}, (a, c)))

    def pair(u, v):
        return Add(u, v, evaluate=False)

    expected = [
        {a: 1, c: pair(2, 3)}, {a: 1, c: pair(3, 2)},
        {a: 2, c: pair(1, 3)}, {a: 2, c: pair(3, 1)},
        {a: 3, c: pair(1, 2)}, {a: 3, c: pair(2, 1)},
        {a: pair(1, 2), c: 3}, {a: pair(2, 1), c: 3},
        {a: pair(1, 3), c: 2}, {a: pair(3, 1), c: 2},
        {a: pair(2, 3), c: 1}, {a: pair(3, 2), c: 1},
    ]
    assert iterdicteq(result, expected)
def test_s_input():
    """A seeded substitution dict constrains unification."""
    a, b = map(Symbol, 'ab')
    wilds = (a, b)
    expr = Basic(1, 2)
    pattern = Basic(a, b)
    assert list(unify(expr, pattern, {}, wilds)) == [{a: 1, b: 2}]
    # Seeding a=5 conflicts with the 1 in expr, so nothing matches.
    assert list(unify(expr, pattern, {a: 5}, wilds)) == []
def test_unify_variables():
    """Only the declared variable x is bound; the leading S(1) matches literally."""
    expr = Basic(S(1), S(2))
    pattern = Basic(S(1), x)
    matches = list(unify(expr, pattern, {}, variables=(x,)))
    assert matches == [{x: 2}]
def test_patternify_with_types():
    """A ``types`` restriction on x limits it to Mul, so it binds the product."""
    a, b, c, x, y = symbols('a,b,c,x,y')
    restricted = patternify(x + y, x, y, types={x: Mul})
    target = a * b + c
    assert list(unify(target, restricted)) == [{x: a * b, y: c}]
def test_unify_variables():
    """Only the declared variable x is bound when unifying two Basics."""
    concrete = Basic(1, 2)
    template = Basic(1, x)
    found = list(unify(concrete, template, {}, variables=(x,)))
    assert found == [{x: 2}]
def test_and():
    """Unify a conjunction of relationals against the pattern ``And(x, y)``.

    ``x`` and ``y`` are the pattern variables. The expression is built only
    from the non-variable symbols ``z`` and ``n``, so each binding is a
    concrete relational and can be asserted exactly.
    """
    from sympy.logic.boolalg import And
    variables = x, y
    expected = ({x: z > 0, y: n < 3},)
    matches = unify((z > 0) & (n < 3), And(x, y), variables=variables)
    assert iterdicteq(matches, expected)
def test_and():
    """Unify a conjunction of relationals against the pattern ``And(x, y)``."""
    expected = ({x: z > 0, y: n < 3},)
    conjunction = (z > 0) & (n < 3)
    template = And(x, y)
    matches = unify(conjunction, template, variables=(x, y))
    assert iterdicteq(matches, expected)
def test_hard_match():
    """The same variable p must bind consistently inside sin and cos."""
    from sympy import sin, cos
    p, q = map(Symbol, 'pq')
    wilds = (p, q)
    pattern = sin(p) + cos(p)**2
    expr = sin(x) + cos(x)**2
    assert list(unify(expr, pattern, {}, wilds)) == [{p: x}]
def test_unify():
    """ExprWild slots in a Basic pattern bind to the matching arguments."""
    a, b, c = map(ExprWild, 'abc')
    expr = Basic(1, 2, 3)
    pattern = Basic(a, b, c)
    want = [{a: 1, b: 2, c: 3}]
    assert list(unify(expr, pattern, {})) == want
    # Omitting the substitution argument gives the same result as passing {}.
    assert list(unify(expr, pattern)) == want
def test_s_input():
    """A seeded substitution for an ExprWild constrains unification."""
    a, b = map(ExprWild, 'ab')
    pattern = Basic(a, b)
    expr = Basic(1, 2)
    assert list(unify(expr, pattern, {})) == [{a: 1, b: 2}]
    # a is pre-bound to 5, which conflicts with the 1 in expr.
    assert list(unify(expr, pattern, {a: 5})) == []
def test_patternify():
    """patternify wraps only the requested symbol in a Variable."""
    wrapped = deconstruct(patternify(x + y, x))
    # The argument order of the Add is not fixed, so accept either order.
    acceptable = (Compound(Add, (Variable(x), y)),
                  Compound(Add, (y, Variable(x))))
    assert wrapped in acceptable
    pattern = patternify(x**2 + y**2, x)
    assert list(unify(pattern, w**2 + y**2, {})) == [{x: w}]
def test_hard_match():
    """A variable shared by sin and cos must bind to the same value in both."""
    from sympy.functions.elementary.trigonometric import (cos, sin)
    p, q = map(Symbol, 'pq')
    target = sin(x) + cos(x)**2
    template = sin(p) + cos(p)**2
    found = list(unify(target, template, {}, (p, q)))
    assert found == [{p: x}]
def test_patternify():
    """Check patternify's deconstructed form and that its pattern unifies."""
    first = Compound(Add, (Variable(x), y))
    second = Compound(Add, (y, Variable(x)))
    assert deconstruct(patternify(x + y, x)) in (first, second)
    squares = patternify(x**2 + y**2, x)
    # Only x is a variable, so y**2 must match literally.
    assert list(unify(squares, w**2 + y**2, {})) == [{x: w}]
def rewrite_rl(expr):
    """Yield ``p2`` instantiated with each binding that unifies ``p1`` with ``expr``.

    ``p1``, ``p2`` and ``variables`` are free names defined outside this
    function. Results that are ``Expr`` instances are passed through ``rebuild``.
    """
    for binding in unify(p1, expr, {}, variables=variables):
        result = subs(binding)(p2)
        yield rebuild(result) if isinstance(result, Expr) else result
def test_and():
    """Smoke test: unify a conjunction with a patternified And, then stringify it.

    Nothing is asserted; the test only checks that no exception is raised.
    """
    pattern = patternify(And(x, y), x, y)
    conjunction = (x > 0) & (z < 3)
    matches = list(unify(conjunction, pattern))
    str(matches)
def test_and():
    """Unify a conjunction of relationals against the pattern ``And(x, y)``.

    ``x`` and ``y`` are the pattern variables. The expression uses only the
    non-variable symbols ``z`` and ``n``, so each binding is a concrete
    relational and can be asserted exactly.
    """
    from sympy.logic.boolalg import And
    variables = x, y
    # Build the pattern locally rather than relying on an undefined name.
    pattern = And(x, y)
    expected = ({x: z > 0, y: n < 3},)
    assert iterdicteq(unify((z > 0) & (n < 3), pattern, variables=variables),
                      expected)
def test_Union():
    """Treating Interval(12, 13) as a variable lets two interval unions unify."""
    from sympy import Interval
    variable_interval = Interval(12, 13)
    expr = Interval(0, 1) + Interval(10, 11)
    pattern = Interval(0, 1) + Interval(12, 13)
    found = list(unify(expr, pattern, variables=(variable_interval,)))
    assert found
def test_commutative_in_commutative():
    """Unifying nested commutative operations produces a truthy first match."""
    from sympy.abc import a, b, c, d
    from sympy import sin, cos
    pat = a*cos(b)*cos(c) + d*sin(b)*sin(c)
    eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4)
    matches = unify(eq, pat, variables=(a, b, c, d))
    assert next(matches)
def test_and():
    """Smoke test: unifying a conjunction with a patternified And must not raise.

    The match list is converted to a string; nothing is asserted.
    """
    template = patternify(And(x, y), x, y)
    str(list(unify((x > 0) & (z < 3), template)))
def test_patternify_with_types():
    """Restricting x to Mul via ``types`` makes it bind to the product term."""
    a, b, c, x, y = symbols('a,b,c,x,y')
    type_map = {x: Mul}
    pattern = patternify(x + y, x, y, types=type_map)
    expected = [{x: a*b, y: c}]
    assert list(unify(a*b + c, pattern)) == expected
def test_unify_variables():
    """x is the only variable, so it must take the value 2."""
    wilds = (x,)
    result = unify(Basic(1, 2), Basic(1, x), {}, variables=wilds)
    assert list(result) == [{x: 2}]
def test_hard_match():
    """An ExprWild shared by sin and cos must bind consistently."""
    from sympy import sin, cos
    p, q = map(ExprWild, 'pq')
    template = sin(p) + cos(p)**2
    target = sin(x) + cos(x)**2
    assert list(unify(target, template, {})) == [{p: x}]
def test_commutative_in_commutative():
    """An Add of Muls (both commutative) yields a truthy first match."""
    from sympy.abc import a, b, c, d
    from sympy.functions.elementary.trigonometric import (cos, sin)
    eq = sin(3) * sin(4) * sin(5) + 4 * cos(3) * cos(4)
    pat = a * cos(b) * cos(c) + d * sin(b) * sin(c)
    wilds = (a, b, c, d)
    first_match = next(unify(eq, pat, variables=wilds))
    assert first_match