Пример #1
0
def test_finditer():
    """Check that ``finditer`` yields every sub-expression matching a pattern,
    including matches nested inside larger expressions"""
    hs = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=hs) for label in ("a", "b", "c"))
    custom_hs = LocalSpace("h1", local_identifiers={'Create': 'c'})
    c_as_create = Create(hs=custom_hs)

    expr = 2 * (a * b * c - b * c * a + a * b)

    sym_pattern = wc('sym', head=OperatorSymbol)
    # every yielded match must carry the named wildcard
    for found in sym_pattern.finditer(expr):
        assert 'sym' in found
    all_found = list(sym_pattern.finditer(expr))
    # the symbols occur 3 + 3 + 2 = 8 times in total
    assert len(all_found) == 8
    assert {found['sym'] for found in all_found} == {a, b, c}

    any_op = wc(head=Operator)
    products = pattern(OperatorTimes, any_op, any_op, any_op).findall(expr)
    assert products == [a * b * c, b * c * a]

    local_op = pattern(LocalOperator)
    assert len(list(local_op.finditer(expr))) == 0
    # after substituting `c` by a Create operator, two LocalOperator
    # instances are found
    substituted = expr.substitute({c: c_as_create})
    assert len(list(local_op.finditer(substituted))) == 2
Пример #2
0
def test_wc_names():
    """Check that ``wc_names`` collects the names of all wildcards that occur
    anywhere inside a (nested) pattern"""
    int_or_str = (int, str)
    ra, rb, rc, rd = (
        wc(name, head=int_or_str) for name in ("ra", "rb", "rc", "rd")
    )
    ls = wc("ls", head=LocalSpace)
    first_sigma = pattern(LocalSigma, ra, rb, hs=ls)
    second_sigma = pattern(LocalSigma, rc, rd, hs=ls)
    pat = pattern_head(first_sigma, second_sigma)
    # `ls` is used by both sub-patterns but appears once in the set
    assert pat.wc_names == {'ra', 'rb', 'rc', 'rd', 'ls'}
Пример #3
0
def test_no_match():
    """Test that matches fail for the correct reason"""

    def assert_no_match(pat, expr, reason):
        """Assert that matching `pat` against `expr` fails, citing `reason`"""
        failed = pat.match(expr)
        assert not failed
        assert reason in failed.reason

    too_few_args = 'insufficient number of arguments'

    # 10 satisfies `i > 0` but violates the second condition `i < 10`
    bounds = [lambda i: i > 0, lambda i: i < 10]
    assert_no_match(
        wc('i__', head=int, conditions=bounds), 10,
        'does not meet condition 2')

    # positional arguments: not enough values for the required slots
    assert_no_match(
        pattern_head(pattern(int), pattern(int), wc('i___', head=int)),
        ProtoExpr([1], {}), too_few_args)
    assert_no_match(
        pattern_head(1, 2, 3), ProtoExpr([1, 2], {}), too_few_args)
    # `i__` needs at least one argument after the leading int
    assert_no_match(
        pattern_head(pattern(int), wc('i__', head=int)),
        ProtoExpr([1], {}), too_few_args)

    # keyword arguments: missing key, wrong type, non-ProtoExpr value
    assert_no_match(
        pattern_head(a=pattern(int), b=pattern(int)),
        ProtoExpr([], {'a': 1, 'c': 2}), "has no keyword argument 'b'")
    assert_no_match(
        pattern_head(a=pattern(int), b=pattern(str)),
        ProtoExpr([], {'a': 1, 'b': 2}), "2 is not an instance of str")
    assert_no_match(
        pattern_head(
            a=pattern(int), b=pattern_head(pattern(int), pattern(int))),
        ProtoExpr([], {'a': 1, 'b': 2}), "2 is not an instance of ProtoExpr")

    # more positional arguments than the pattern accepts
    assert_no_match(
        pattern_head(pattern(int)), ProtoExpr([1, 2], {}),
        'too many positional arguments')

    mismatch = match_pattern(1, 2)
    assert not mismatch.success
    assert "Expressions '1' and '2' are not the same" in mismatch.reason
Пример #4
0
def test_simplify():
    """Test simplification of expr according to manual rules"""
    h1 = LocalSpace("h1")
    a, b, c, d = (OperatorSymbol(label, hs=h1) for label in "abcd")

    expr = 2 * (a * b * c - b * c * a)

    A_ = wc('A', head=Operator)
    B_ = wc('B', head=Operator)
    C_ = wc('C', head=Operator)

    # Argument names must stay `B`/`C` (and `A`/`B` below): they correspond
    # to the wildcard names of the patterns the functions are registered with.
    def b_times_c_equal_d(B, C):
        # only the product b*c is rewritten; anything else is rejected
        if B.label != 'b' or C.label != 'c':
            raise CannotSimplify
        return d

    def under_extra_times_rule(compute):
        """Return ``compute()``, evaluated while OperatorTimes temporarily
        carries the extra rule b*c -> d"""
        with temporary_rules(OperatorTimes):
            OperatorTimes.add_rule(
                'extra', pattern_head(B_, C_), b_times_c_equal_d)
            return compute()

    new_expr = under_extra_times_rule(lambda: expr.rebuild())

    def commutator(A, B):
        label = "Commut%s%s" % (A.label.upper(), B.label.upper())
        return OperatorSymbol(label, hs=A.space)

    commutator_rule = (
        pattern(
            OperatorPlus,
            pattern(OperatorTimes, A_, B_),
            pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, B_, A_)),
        ),
        commutator,
    )
    assert commutator_rule[0].match(new_expr.term)

    # with b*c collapsed to d, the commutator rule sees a*d - d*a
    new_expr = under_extra_times_rule(
        lambda: _apply_rules(expr, [commutator_rule]))
    expected = (
        "ScalarTimesOperator(ScalarValue(2), OperatorSymbol('CommutAD', "
        "hs=LocalSpace('h1')))")
    assert srepr(new_expr) == expected
Пример #5
0
def test_extra_binary_rules():
    """Test creation of expr with extra binary rules"""
    h1 = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in "abc")
    A_ = wc('A', head=Operator)
    B_ = wc('B', head=Operator)
    rule_pattern = pattern_head(
        pattern(OperatorTimes, A_, B_),
        pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, B_, A_)),
    )

    # argument names match the wildcard names `A` and `B`
    def replacement(A, B):
        return c

    rule = (rule_pattern, replacement)
    with temporary_rules(OperatorPlus):
        OperatorPlus.add_rule('extra', rule_pattern, replacement)
        # the rule is registered under its name ...
        assert ('extra', rule) in OperatorPlus._binary_rules.items()
        # ... and replaces the commutator a*b - b*a by c
        expr = 2 * (a * b - b * a + IdentityOperator)
        assert expr == 2 * (c + IdentityOperator)
    # once the context has exited, the rule is no longer registered
    assert rule not in OperatorPlus._binary_rules.values()
Пример #6
0
def test_invalid_pattern():
    """Test that instantiating a Pattern with invalid attributes raises the
    appropriate exceptions"""
    # Messages are checked via ``exc_info.value``: since pytest 5.0,
    # ``str(exc_info)`` is the repr of the ExceptionInfo wrapper, in which the
    # exception's repr is escaped and may be truncated, so substring checks
    # against it are unreliable.
    with pytest.raises(TypeError) as exc_info:
        Pattern(head='OperatorSymbol')
    assert 'must be class' in str(exc_info.value)
    # a multi-occurrence wildcard (`b__`) in the middle position is rejected
    with pytest.raises(ValueError) as exc_info:
        pattern(ScalarTimesOperator, wc('a'), wc('b__'), wc('c'))
    assert (
        'Only the first or last argument may have a mode indicating an '
        'occurrence of more than 1' in str(exc_info.value)
    )
    # five trailing underscores do not form a valid name/mode suffix
    with pytest.raises(ValueError) as exc_info:
        wc('a_____')
    assert "Invalid name_mode" in str(exc_info.value)
    # neither an arbitrary int nor a string is accepted as `mode`
    with pytest.raises(ValueError) as exc_info:
        pattern(ScalarTimesOperator, wc('a'), wc('b'), wc_name='S', mode=5)
    assert "Mode must be one of" in str(exc_info.value)
    with pytest.raises(ValueError) as exc_info:
        pattern(ScalarTimesOperator, wc('a'), wc('b'), wc_name='S', mode='1')
    assert "Mode must be one of" in str(exc_info.value)
Пример #7
0
def test_findall():
    """Check that ``findall`` returns the list of all sub-expressions matching
    a pattern"""
    h1 = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))
    c_local = Create(hs=LocalSpace("h1", local_identifiers={'Create': 'c'}))

    expr = 2 * (a * b * c - b * c * a + a * b)

    found_symbols = pattern(OperatorSymbol).findall(expr)
    # the symbols occur 3 + 3 + 2 = 8 times in total
    assert len(found_symbols) == 8
    assert set(found_symbols) == {a, b, c}

    any_op = wc(head=Operator)
    triples = pattern(OperatorTimes, any_op, any_op, any_op).findall(expr)
    assert triples == [a * b * c, b * c * a]

    local_ops = pattern(LocalOperator)
    assert len(local_ops.findall(expr)) == 0
    assert len(local_ops.findall(expr.substitute({c: c_local}))) == 2
    def testMatch(self):
        """Check that a named wildcard binds a single expression, and that a
        wildcard repeated in a pattern accepts equal arguments"""
        # NOTE(review): this function is nested inside ``test_findall`` and
        # nothing calls it, so none of these assertions currently execute.
        A = wc("A", head=SuperOperator)
        a, b, b2 = (
            SuperOperatorSymbol(label, hs="hs") for label in ("a", "b", "b")
        )

        # distinct instances with the same label and space compare equal
        assert b == b2
        for sym in (a, b):
            assert A.match(sym)
            assert A.match(sym)['A'] == sym

        pat = pattern_head(A, A)
        for args in ([b, b], [b, b2]):
            expr = ProtoExpr(args=args, kwargs={})
            assert pat.match(expr)
            assert pat.match(expr)['A'] == b
Пример #9
0
def test_wc():
    """Test that the wc() constructor produces the equivalent Pattern
    instance"""

    def expected(**overrides):
        """Build a Pattern from the attributes of a bare ``wc()``, with
        `overrides` applied on top"""
        attrs = dict(
            head=None,
            args=None,
            kwargs=None,
            mode=Pattern.single,
            wc_name=None,
            conditions=None,
        )
        attrs.update(overrides)
        return Pattern(**attrs)

    patterns = [
        (wc(), expected()),
        (wc('a'), expected(wc_name='a')),
        # a single trailing underscore is the same as no suffix
        (wc('a_'), wc('a')),
        (wc('a__'), expected(mode=Pattern.one_or_more, wc_name='a')),
        (wc('a___'), expected(mode=Pattern.zero_or_more, wc_name='a')),
        (wc('a', head=int), expected(head=int, wc_name='a')),
        (
            wc('a', head=(int, float)),
            expected(head=(int, float), wc_name='a'),
        ),
    ]
    for pat1, pat2 in patterns:
        print(repr(pat1))
        assert pat1 == pat2
    # a name made up of four underscores is rejected
    with pytest.raises(ValueError):
        wc("____")
Пример #10
0
        assert pat1 == pat2


# test expressions
two_t = 2 * Symbol('t')
two_O = 2 * OperatorSymbol('O', hs=FullSpace)
proto_two_O = ProtoExpr(
    args=[2, OperatorSymbol('O', hs=FullSpace)], kwargs={})
proto_kwargs = ProtoExpr(args=[1, 2], kwargs={'a': '3', 'b': 4})
proto_kw_only = ProtoExpr(args=[], kwargs={'a': 1, 'b': 2})
# positional-only ProtoExprs holding the integers 1..N
proto_ints2 = ProtoExpr(args=list(range(1, 3)), kwargs={})
proto_ints3 = ProtoExpr(args=list(range(1, 4)), kwargs={})
proto_ints4 = ProtoExpr(args=list(range(1, 5)), kwargs={})
proto_ints5 = ProtoExpr(args=list(range(1, 6)), kwargs={})

# test patterns and wildcards
# wildcard `a`: a ScalarValue or int that must equal 2 (resp. 3)
wc_a_int_2 = wc(
    'a', head=(ScalarValue, int), conditions=[lambda i: i == 2])
wc_a_int_3 = wc(
    'a', head=(ScalarValue, int), conditions=[lambda i: i == 3])
wc_a_int = wc('a', head=int)
wc_label_str = wc('label', head=str)
wc_hs = wc('space', head=HilbertSpace)
# 2 * O described three ways: as a ScalarTimesOperator pattern with a
# wildcard operator, as a pattern_head over the arguments only, and with the
# operator given literally
pattern_two_O = pattern(
    ScalarTimesOperator, wc_a_int_2,
    pattern(OperatorSymbol, wc_label_str, hs=wc_hs))
pattern_two_O_head = pattern_head(
    wc_a_int_2, pattern(OperatorSymbol, wc_label_str, hs=wc_hs))
pattern_two_O_expr = pattern(
    ScalarTimesOperator, wc_a_int_2, OperatorSymbol('O', hs=FullSpace))