def test_wilds_in_wilds():
    """A wild whose shape contains another wild is kept whole.

    ``B`` is declared as a wild, and its shape contains ``m``, which is
    also a wild. The expected pattern keeps ``B`` as a single ``Variable``.
    Only the ``A`` operand is expanded into a ``Compound`` over its wild
    name and dimensions.
    """
    from sympy import MatrixSymbol, MatMul
    left = MatrixSymbol('A', n, m)
    right = MatrixSymbol('B', m, k)
    # m is a wild on its own and also appears inside the wild `right`.
    pattern = patternify(left * right, 'A', n, m, right)
    left_part = Compound(MatrixSymbol,
                         (Variable('A'), Variable(n), Variable(m)))
    expected = Compound(MatMul, (left_part, Variable(right)))
    assert deconstruct(pattern) == expected
def test_deconstruct():
    """``deconstruct`` maps Basic trees to Compounds and marks listed variables."""
    # A Basic node becomes a Compound over its arguments.
    assert deconstruct(Basic(1, 2, 3)) == Compound(Basic, (1, 2, 3))
    # Non-Basic values and atoms pass through unchanged...
    assert deconstruct(1) == 1
    assert deconstruct(x) == x
    # ...unless they are named in `variables`.
    assert deconstruct(x, variables=(x, )) == Variable(x)
    unevaluated_sum = Add(1, x, evaluate=False)
    assert deconstruct(unevaluated_sum) == Compound(Add, (1, x))
    assert deconstruct(unevaluated_sum, variables=(x,)) == \
        Compound(Add, (1, Variable(x)))
def patternify(expr, *wilds, **kwargs):
    """ Create a matching pattern from an expression

    Every symbol in ``wilds`` is substituted in ``expr``:

    - by a ``CondVariable``, when the symbol has an entry in the ``types``
      keyword argument (the condition is built with ``mk_matchtype``);
    - by a plain ``Variable`` otherwise.

    Raises
    ======

    NotImplementedError
        If ``expr`` contains any of the classes listed in ``illegal``.

    Examples
    ========

    >>> from sympy import symbols, sin, cos, Mul
    >>> from sympy.unify.usympy import patternify
    >>> a, b, c, x, y = symbols('a b c x y')

    >>> # Search for anything of the form sin(foo)**2 + cos(foo)**2
    >>> pattern = patternify(sin(x)**2 + cos(x)**2, x)

    >>> # Search for any two things added to c. Note that here c is not a wild
    >>> pattern = patternify(a + b + c, a, b)

    >>> # Search for two things added together, one must be a Mul
    >>> pattern = patternify(a + b, a, b, types={a: Mul})
    """
    from sympy.rules.tools import subs

    types = kwargs.get('types', {})
    variables = [CondVariable(wild, mk_matchtype(types[wild]))
                 if wild in types else Variable(wild)
                 for wild in wilds]

    # Report the first unsupported class actually present in expr.
    # The message must only use names that exist in this scope; otherwise
    # the caller gets a NameError instead of NotImplementedError.
    offending = next((cls for cls in illegal if expr.has(cls)), None)
    if offending is not None:
        raise NotImplementedError(
            "Unification not supported on type %s" % (offending,))

    return subs(dict(zip(wilds, variables)))(expr)
def deconstruct(s):
    """Convert a SymPy object into nested ``Compound`` objects.

    - ``ExprWild`` instances are wrapped in ``Variable``.
    - Existing ``Variable`` / ``CondVariable`` objects are returned as-is.
    - Non-Basic values and atoms are returned unchanged.
    - Any other Basic node becomes a ``Compound`` of its class and its
      recursively deconstructed arguments.
    """
    if isinstance(s, ExprWild):
        return Variable(s)
    is_leaf = (isinstance(s, (Variable, CondVariable))
               or not isinstance(s, Basic)
               or s.is_Atom)
    if is_leaf:
        return s
    children = tuple(deconstruct(arg) for arg in s.args)
    return Compound(s.__class__, children)
def deconstruct(s, variables=()):
    """Convert a SymPy object into nested ``Compound`` objects.

    Parameters
    ==========

    s : object
        The value to convert.
    variables : tuple
        Objects that should be wrapped in ``Variable`` wherever they occur.

    Existing ``Variable`` / ``CondVariable`` objects, non-Basic values and
    atoms are returned unchanged. Other Basic nodes become a ``Compound``
    of their class and recursively deconstructed arguments.
    """
    if s in variables:
        return Variable(s)
    if isinstance(s, (Variable, CondVariable)):
        return s
    if isinstance(s, Basic) and not s.is_Atom:
        parts = [deconstruct(arg, variables) for arg in s.args]
        return Compound(s.__class__, tuple(parts))
    return s
def main(n):
    """Repeatedly unify an 8-argument ``Add`` compound against a pattern.

    The pattern has a ``Variable`` in every argument slot. This is
    presumably a timing harness; the repeat count ``n`` only controls how
    many times the unification is run.

    Returns the list of matches produced by the last run, or an empty
    list when ``n`` is 0.
    """
    C1 = C('Add', list(range(8)))
    C2 = C('Add', [Variable(i) for i in range(8)])
    # Bind before the loop so n == 0 returns [] instead of raising
    # UnboundLocalError.
    lst = []
    for _ in range(n):
        # Exhaust the generator so every match is actually computed.
        lst = list(core.unify(C1, C2, {}))
    # Return only after all n runs have completed.
    return lst
def test_CondVariable():
    """A CondVariable only binds values that satisfy its predicate."""
    expr = C("CAdd", (1, 2))
    x = Variable("x")
    y = CondVariable("y", lambda a: a % 2 == 0)
    # x is unconstrained, y accepts only even values.
    pattern = C("CAdd", (x, y))
    assert list(unify(expr, pattern, {})) == [{x: 1, y: 2}]
    # z requires a value > 3, but expr only contains 1 and 2, so no match.
    z = CondVariable("z", lambda a: a > 3)
    pattern = C("CAdd", (z, y))
    assert list(unify(expr, pattern, {})) == []
def test_CondVariable():
    """A CondVariable only binds values that satisfy its predicate."""
    expr = C('CAdd', (1, 2))
    x = Variable('x')
    y = CondVariable('y', lambda a: a % 2 == 0)
    # x is unconstrained, y accepts only even values.
    pattern = C('CAdd', (x, y))
    assert list(unify(expr, pattern, {})) == \
        [{x: 1, y: 2}]
    # z requires a value > 3, but expr only contains 1 and 2, so no match.
    z = CondVariable('z', lambda a: a > 3)
    pattern = C('CAdd', (z, y))
    assert list(unify(expr, pattern, {})) == []
def test_defaultdict():
    """``unify`` works when no substitution mapping is passed.

    Unlike the other tests, no third argument is supplied, so this
    presumably exercises unify's default for that parameter.
    """
    expected = {Variable('x'): 'foo'}
    first_match = next(unify(Variable('x'), 'foo'))
    assert first_match == expected
def test_defaultdict():
    """``unify`` works when no substitution mapping is passed.

    Unlike the other tests, no third argument is supplied, so this
    presumably exercises unify's default for that parameter.
    """
    expected = {Variable("x"): "foo"}
    first_match = next(unify(Variable("x"), "foo"))
    assert first_match == expected
def test_patternify():
    """patternify turns only the requested symbols into wilds."""
    deconstructed = deconstruct(patternify(x + y, x))
    # The argument order inside Add is not pinned here, so accept either.
    acceptable = (Compound(Add, (Variable(x), y)),
                  Compound(Add, (y, Variable(x))))
    assert deconstructed in acceptable

    squares = patternify(x**2 + y**2, x)
    matches = list(unify(squares, w**2 + y**2, {}))
    assert matches == [{x: w}]