Exemplo n.º 1
0
def test_wilds_in_wilds():
    """Check the deconstructed form of a matrix-product pattern.

    The wilds passed to ``patternify`` are the name 'A', the dimensions
    ``n`` and ``m``, and the whole symbol ``B``.
    """
    from sympy import MatrixSymbol, MatMul
    left = MatrixSymbol('A', n, m)
    right = MatrixSymbol('B', m, k)
    # note that m is in B as well
    pattern = patternify(left * right, 'A', n, m, right)
    actual = deconstruct(pattern)

    # The left factor is expected to be broken into its wild parts, while
    # the right factor is expected to appear as one Variable.
    left_tree = Compound(MatrixSymbol,
                         (Variable('A'), Variable(n), Variable(m)))
    expected = Compound(MatMul, (left_tree, Variable(right)))
    assert actual == expected
def test_deconstruct():
    """Check ``deconstruct`` on a Basic, on atoms, and on an unevaluated Add,
    both with and without a ``variables`` argument."""
    # A Basic with arguments becomes a Compound of its class and arguments.
    tree = Basic(1, 2, 3)
    assert deconstruct(tree) == Compound(Basic, (1, 2, 3))

    # Without ``variables`` a plain number and x come back unchanged.
    assert deconstruct(1) == 1
    assert deconstruct(x) == x

    # Listing x in ``variables`` wraps it in a Variable.
    wild_x = Variable(x)
    assert deconstruct(x, variables=(x, )) == wild_x

    plain_sum = deconstruct(Add(1, x, evaluate=False))
    assert plain_sum == Compound(Add, (1, x))

    wild_sum = deconstruct(Add(1, x, evaluate=False), variables=(x,))
    assert wild_sum == Compound(Add, (1, Variable(x)))
Exemplo n.º 3
0
def patternify(expr, *wilds, **kwargs):
    """ Create a matching pattern from an expression

    Each wild is paired with a ``Variable`` wrapping it, or with a
    ``CondVariable`` built from ``mk_matchtype(types[wild])`` when the wild
    has an entry in ``types``.  That mapping is then applied to ``expr``
    through ``sympy.rules.tools.subs``.

    Parameters
    ==========

    expr : the expression to turn into a pattern; must provide ``has``
    wilds : the objects inside ``expr`` that should act as wilds
    types : optional keyword, a dict mapping a wild to the type it must match

    Raises
    ======

    NotImplementedError : if ``expr.has(cls)`` is true for any ``cls`` in
        ``illegal``

    Example
    =======

    >>> from sympy import symbols, sin, cos, Mul
    >>> from sympy.unify.usympy import patternify
    >>> a, b, c, x, y = symbols('a b c x y')

    >>> # Search for anything of the form sin(foo)**2 + cos(foo)**2
    >>> pattern = patternify(sin(x)**2 + cos(x)**2, x)

    >>> # Search for any two things added to c. Note that here c is not a wild
    >>> pattern = patternify(a + b + c, a, b)

    >>> # Search for two things added together, one must be a Mul
    >>> pattern = patternify(a + b, a, b, types={a: Mul})
    """
    from sympy.rules.tools import subs
    types = kwargs.get('types', {})
    # Named ``wild_vars`` rather than ``vars`` to avoid shadowing the builtin.
    wild_vars = [CondVariable(wild, mk_matchtype(types[wild]))
                 if wild in types else Variable(wild)
                 for wild in wilds]
    if any(expr.has(cls) for cls in illegal):
        # The message previously formatted ``type(s)``, but no ``s`` exists in
        # this scope, so a NameError was raised instead of this exception.
        raise NotImplementedError("Unification not supported on type %s" % (
            type(expr)))
    return subs(dict(zip(wilds, wild_vars)))(expr)
Exemplo n.º 4
0
def deconstruct(s):
    """ Turn a SymPy object into a Compound

    An ``ExprWild`` is wrapped in a ``Variable``.  Existing variables,
    objects that are not ``Basic`` and atoms are returned unchanged.
    Anything else becomes a ``Compound`` of its class and its recursively
    deconstructed arguments.
    """
    if isinstance(s, ExprWild):
        return Variable(s)

    already_variable = isinstance(s, (Variable, CondVariable))
    if already_variable or not isinstance(s, Basic) or s.is_Atom:
        # Nothing to take apart: hand the object back as it is.
        return s

    children = tuple(deconstruct(arg) for arg in s.args)
    return Compound(s.__class__, children)
Exemplo n.º 5
0
def deconstruct(s, variables=()):
    """ Turn a SymPy object into a Compound

    Any object found in ``variables`` is wrapped in a ``Variable``.
    Existing variables, objects that are not ``Basic`` and atoms are
    returned unchanged.  Anything else becomes a ``Compound`` of its class
    and its arguments, each deconstructed with the same ``variables``.
    """
    if s in variables:
        return Variable(s)

    already_variable = isinstance(s, (Variable, CondVariable))
    if already_variable or not isinstance(s, Basic) or s.is_Atom:
        # Nothing to take apart: hand the object back as it is.
        return s

    children = tuple([deconstruct(child, variables) for child in s.args])
    return Compound(s.__class__, children)
Exemplo n.º 6
0
def main(n):
    """Unify two eight-argument 'Add' terms ``n`` times.

    The first term holds the integers 0..7 and the second holds
    ``Variable(0)``..``Variable(7)``.  Each pass collects every result
    yielded by ``core.unify`` for the pair.

    Returns the list collected on the last pass, or an empty list when
    ``n`` is less than 1.
    """
    C1 = C('Add', list(range(8)))
    C2 = C('Add', [Variable(i) for i in range(8)])

    # Bound before the loop so that n < 1 returns [] rather than raising
    # UnboundLocalError at the return statement.
    lst = []
    for _ in range(n):
        lst = list(core.unify(C1, C2, {}))

    return lst
Exemplo n.º 7
0
def test_CondVariable():
    """Check that unification respects the predicate of a CondVariable.

    ``y`` carries an "is even" predicate and ``z`` a "greater than 3"
    predicate.  Against ``CAdd(1, 2)`` the pattern ``(x, y)`` is expected to
    give exactly one match, and the pattern ``(z, y)`` none.
    """
    expr = C("CAdd", (1, 2))
    x = Variable("x")
    y = CondVariable("y", lambda a: a % 2 == 0)
    pattern = C("CAdd", (x, y))
    assert list(unify(expr, pattern, {})) == [{x: 1, y: 2}]

    # z is created only here; an identical earlier assignment was never read.
    z = CondVariable("z", lambda a: a > 3)
    pattern = C("CAdd", (z, y))

    assert list(unify(expr, pattern, {})) == []
Exemplo n.º 8
0
def test_CondVariable():
    """Check that unification respects the predicate of a CondVariable.

    ``y`` carries an "is even" predicate and ``z`` a "greater than 3"
    predicate.  Against ``CAdd(1, 2)`` the pattern ``(x, y)`` is expected to
    give exactly one match, and the pattern ``(z, y)`` none.
    """
    expr = C('CAdd', (1, 2))
    x = Variable('x')
    y = CondVariable('y', lambda a: a % 2 == 0)
    pattern = C('CAdd', (x, y))
    assert list(unify(expr, pattern, {})) == [{x: 1, y: 2}]

    # z is created only here; an identical earlier assignment was never read.
    z = CondVariable('z', lambda a: a > 3)
    pattern = C('CAdd', (z, y))

    assert list(unify(expr, pattern, {})) == []
Exemplo n.º 9
0
def test_defaultdict():
    """Check ``unify`` when it is called without a substitution argument."""
    first_match = next(unify(Variable('x'), 'foo'))
    expected = {Variable('x'): 'foo'}
    assert first_match == expected
Exemplo n.º 10
0
def test_defaultdict():
    """Check ``unify`` when it is called without a substitution argument."""
    first_match = next(unify(Variable("x"), "foo"))
    expected = {Variable("x"): "foo"}
    assert first_match == expected
Exemplo n.º 11
0
def test_patternify():
    """Check ``patternify`` through ``deconstruct`` and through ``unify``."""
    # Either argument order of the Add is accepted.
    sum_tree = deconstruct(patternify(x + y, x))
    accepted = (Compound(Add, (Variable(x), y)),
                Compound(Add, (y, Variable(x))))
    assert sum_tree in accepted

    # With x as the only wild, matching against w**2 + y**2 is expected to
    # bind x to w.
    pattern = patternify(x**2 + y**2, x)
    matches = list(unify(pattern, w**2 + y**2, {}))
    assert matches == [{x: w}]