def test_wilds_in_wilds():
    """Wild symbols used inside a matrix's shape become Variables as well.

    ``A`` is named as wild by its string name and its shape symbols ``n``/``m``
    are wild too, so its MatrixSymbol node keeps its structure with Variable
    leaves. ``B`` is itself listed as wild, so it collapses to one Variable
    even though ``m`` also appears in its shape.
    """
    from sympy import MatrixSymbol, MatMul
    mat_a = MatrixSymbol('A', n, m)
    mat_b = MatrixSymbol('B', m, k)
    pattern = patternify(mat_a * mat_b, 'A', n, m, B=None) if False else \
        patternify(mat_a * mat_b, 'A', n, m, mat_b)
    left = Compound(MatrixSymbol, (Variable('A'), Variable(n), Variable(m)))
    expected = Compound(MatMul, (left, Variable(mat_b)))
    assert deconstruct(pattern) == expected
def test_deconstruct():
    """deconstruct maps Basic trees to Compounds and wraps requested variables."""
    assert deconstruct(Basic(1, 2, 3)) == Compound(Basic, (1, 2, 3))
    # Non-Basic values and atoms pass through untouched unless listed as variables.
    for leaf in (1, x):
        assert deconstruct(leaf) == leaf
    assert deconstruct(x, variables=(x,)) == Variable(x)
    unevaluated_sum = Add(1, x, evaluate=False)
    assert deconstruct(unevaluated_sum) == Compound(Add, (1, x))
    assert deconstruct(unevaluated_sum, variables=(x,)) == \
        Compound(Add, (1, Variable(x)))
def deconstruct(s):
    """Convert a SymPy object into nested ``Compound`` nodes.

    * ``ExprWild`` instances become ``Variable`` leaves.
    * Existing ``Variable``/``CondVariable`` objects, non-``Basic`` values and
      atomic ``Basic`` objects are returned unchanged.
    * Any other ``Basic`` becomes ``Compound(<its class>, <converted args>)``.
    """
    if isinstance(s, ExprWild):
        return Variable(s)
    # Short-circuit order matters: is_Atom is only read once s is known to be Basic.
    keep_as_is = (isinstance(s, (Variable, CondVariable))
                  or not isinstance(s, Basic)
                  or s.is_Atom)
    if keep_as_is:
        return s
    return Compound(s.__class__, tuple(deconstruct(arg) for arg in s.args))
def deconstruct(s, variables=()):
    """Convert a SymPy object into nested ``Compound`` nodes.

    Parameters
    ----------
    s : object
        The object to convert.
    variables : container, optional
        Objects that should be wrapped as ``Variable`` wherever they occur
        (membership is tested with ``in``).

    Returns
    -------
    ``Variable(s)`` if ``s`` is in ``variables``; ``s`` itself if it is already
    a ``Variable``/``CondVariable``, is not a ``Basic``, or is an atom;
    otherwise ``Compound(<class of s>, <converted args as a tuple>)``.
    """
    if s in variables:
        return Variable(s)
    # is_Atom is only evaluated for Basic instances thanks to short-circuiting.
    if (isinstance(s, (Variable, CondVariable))
            or not isinstance(s, Basic)
            or s.is_Atom):
        return s
    converted = [deconstruct(arg, variables) for arg in s.args]
    return Compound(s.__class__, tuple(converted))
def test_nested():
    """A nested Basic round-trips through deconstruct and construct."""
    inner = Compound(Basic, (2,))
    cmpd = Compound(Basic, (1, inner, 3))
    expr = Basic(1, Basic(2), 3)
    assert deconstruct(expr) == cmpd
    assert construct(cmpd) == expr
def test_construct():
    """construct rebuilds a flat Basic from its Compound description."""
    rebuilt = construct(Compound(Basic, (1, 2, 3)))
    assert rebuilt == Basic(1, 2, 3)
def test_commutativity():
    """Compounds whose op is 'CAdd' are commutative and yield two unifiers."""
    lhs = Compound('CAdd', (a, b))
    rhs = Compound('CAdd', (x, y))
    assert all(is_commutative(c) for c in (lhs, rhs))
    unifiers = list(unify(lhs, rhs, {}))
    assert len(unifiers) == 2
def test_nested():
    """A nested Basic of SymPy Integers round-trips through deconstruct/construct."""
    expr = Basic(S(1), Basic(S(2)), S(3))
    # Compound args are written as plain Python tuples, matching the
    # ``tuple(...)`` that deconstruct builds for every Compound it creates,
    # so the comparison does not depend on tuple/Tuple coercion.
    cmpd = Compound(Basic, (S(1), Compound(Basic, (S(2),)), S(3)))
    assert deconstruct(expr) == cmpd
    assert construct(cmpd) == expr
def test_construct():
    """construct rebuilds a Basic of SymPy Integers from its Compound form."""
    one, two, three = S(1), S(2), S(3)
    rebuilt = construct(Compound(Basic, (one, two, three)))
    assert rebuilt == Basic(one, two, three)
def test_deconstruct():
    """A flat Basic deconstructs into one Compound carrying the same args."""
    args = (1, 2, 3)
    assert deconstruct(Basic(*args)) == Compound(Basic, args)
def test_patternify():
    """patternify marks the given symbols as unification variables."""
    # The argument order inside Add is not pinned by this test, so accept either.
    acceptable = (Compound(Add, (Variable(x), y)),
                  Compound(Add, (y, Variable(x))))
    assert deconstruct(patternify(x + y, x)) in acceptable
    pattern = patternify(x**2 + y**2, x)
    matches = list(unify(pattern, w**2 + y**2, {}))
    assert matches == [{x: w}]