def test_finditer():
    """Check that ``finditer`` yields one match per matching subexpression."""
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    c = OperatorSymbol("c", hs=h1)
    h1_custom = LocalSpace("h1", local_identifiers={'Create': 'c'})
    c_local = Create(hs=h1_custom)
    expr = 2 * (a * b * c - b * c * a + a * b)

    sym_pattern = wc('sym', head=OperatorSymbol)
    # Every yielded match must carry a binding for the 'sym' wildcard.
    assert all('sym' in found for found in sym_pattern.finditer(expr))
    matches = list(sym_pattern.finditer(expr))
    # a*b*c, b*c*a and a*b contribute 3 + 3 + 2 = 8 symbol occurrences,
    # drawn from only three distinct symbols.
    assert len(matches) == 8
    assert {found['sym'] for found in matches} == {a, b, c}

    any_op = wc(head=Operator)
    products = pattern(OperatorTimes, any_op, any_op, any_op).findall(expr)
    assert products == [a * b * c, b * c * a]

    # No LocalOperator is present until c is replaced by c_local.
    assert len(list(pattern(LocalOperator).finditer(expr))) == 0
    substituted = expr.substitute({c: c_local})
    local_matches = list(pattern(LocalOperator).finditer(substituted))
    assert len(local_matches) == 2
def test_wc_names():
    """Check that ``wc_names`` collects every wildcard name in a nested pattern."""
    ra, rb, rc, rd = (
        wc(name, head=(int, str)) for name in ("ra", "rb", "rc", "rd")
    )
    ls = wc("ls", head=LocalSpace)
    # 'ls' is used by both sub-patterns but must appear only once in the set.
    pat = pattern_head(
        pattern(LocalSigma, ra, rb, hs=ls),
        pattern(LocalSigma, rc, rd, hs=ls),
    )
    assert pat.wc_names == {'ra', 'rb', 'rc', 'rd', 'ls'}
def test_no_match():
    """Check that failed matches report the expected reason."""
    positive_and_below_ten = [lambda i: i > 0, lambda i: i < 10]
    # (pattern, object to match, substring expected in the failure reason)
    failing_cases = [
        (
            wc('i__', head=int, conditions=positive_and_below_ten),
            10,
            'does not meet condition 2',
        ),
        (
            pattern_head(pattern(int), pattern(int), wc('i___', head=int)),
            ProtoExpr([1], {}),
            'insufficient number of arguments',
        ),
        (
            pattern_head(1, 2, 3),
            ProtoExpr([1, 2], {}),
            'insufficient number of arguments',
        ),
        (
            pattern_head(pattern(int), wc('i__', head=int)),
            ProtoExpr([1], {}),
            'insufficient number of arguments',
        ),
        (
            pattern_head(a=pattern(int), b=pattern(int)),
            ProtoExpr([], {'a': 1, 'c': 2}),
            "has no keyword argument 'b'",
        ),
        (
            pattern_head(a=pattern(int), b=pattern(str)),
            ProtoExpr([], {'a': 1, 'b': 2}),
            "2 is not an instance of str",
        ),
        (
            pattern_head(
                a=pattern(int), b=pattern_head(pattern(int), pattern(int))
            ),
            ProtoExpr([], {'a': 1, 'b': 2}),
            "2 is not an instance of ProtoExpr",
        ),
        (
            pattern_head(pattern(int)),
            ProtoExpr([1, 2], {}),
            'too many positional arguments',
        ),
    ]
    for pat, target, expected_reason in failing_cases:
        result = pat.match(target)
        assert not result
        assert expected_reason in result.reason

    # match_pattern compares two plain (non-pattern) expressions directly.
    result = match_pattern(1, 2)
    assert not result.success
    assert "Expressions '1' and '2' are not the same" in result.reason
def test_simplify():
    """Test simplification of expr according to manual rules"""
    h1 = LocalSpace("h1")
    a, b, c, d = (OperatorSymbol(label, hs=h1) for label in "abcd")

    expr = 2 * (a * b * c - b * c * a)

    A_ = wc('A', head=Operator)
    B_ = wc('B', head=Operator)
    C_ = wc('C', head=Operator)

    def b_times_c_equal_d(B, C):
        # Only the specific product b*c is rewritten (to d); any other pair
        # signals that this rule does not apply.
        if B.label != 'b' or C.label != 'c':
            raise CannotSimplify
        return d

    with temporary_rules(OperatorTimes):
        OperatorTimes.add_rule('extra', pattern_head(B_, C_), b_times_c_equal_d)
        rebuilt = expr.rebuild()

    def commutator_symbol(A, B):
        # Replace A*B - B*A by a fresh symbol named after both factors.
        label = "Commut%s%s" % (A.label.upper(), B.label.upper())
        return OperatorSymbol(label, hs=A.space)

    commutator_pattern = pattern(
        OperatorPlus,
        pattern(OperatorTimes, A_, B_),
        pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, B_, A_)),
    )
    commutator_rule = (commutator_pattern, commutator_symbol)
    assert commutator_pattern.match(rebuilt.term)

    # The extra OperatorTimes rule is re-activated while the manual
    # commutator rule is applied to the original expr.
    with temporary_rules(OperatorTimes):
        OperatorTimes.add_rule('extra', pattern_head(B_, C_), b_times_c_equal_d)
        simplified = _apply_rules(expr, [commutator_rule])
    expected_srepr = (
        "ScalarTimesOperator(ScalarValue(2), OperatorSymbol('CommutAD', "
        "hs=LocalSpace('h1')))"
    )
    assert srepr(simplified) == expected_srepr
def test_extra_binary_rules():
    """Test creation of expr with extra binary rules"""
    h1 = LocalSpace("h1")
    a = OperatorSymbol("a", hs=h1)
    b = OperatorSymbol("b", hs=h1)
    c = OperatorSymbol("c", hs=h1)

    A_ = wc('A', head=Operator)
    B_ = wc('B', head=Operator)

    # Two OperatorPlus operands of the form A*B and -1*(B*A).
    commutator = pattern_head(
        pattern(OperatorTimes, A_, B_),
        pattern(ScalarTimesOperator, -1, pattern(OperatorTimes, B_, A_)),
    )
    rule = (commutator, lambda A, B: c)

    with temporary_rules(OperatorPlus):
        OperatorPlus.add_rule('extra', *rule)
        assert ('extra', rule) in OperatorPlus._binary_rules.items()
        # While the rule is active, a*b - b*a is replaced by c on creation.
        created = 2 * (a * b - b * a + IdentityOperator)
        assert created == 2 * (c + IdentityOperator)
    # Leaving the context must remove the temporary rule again.
    assert rule not in OperatorPlus._binary_rules.values()
def test_invalid_pattern():
    """Test that instantiating a Pattern with invalid attributes raises the
    appropriate exceptions"""
    # Check the message of the raised exception (``exc_info.value``), not
    # ``exc_info`` itself: since pytest 5.0, ``str(ExceptionInfo)`` is the
    # (saferepr-truncated, escaped) repr of the wrapper object rather than
    # the exception message.
    with pytest.raises(TypeError) as exc_info:
        Pattern(head='OperatorSymbol')
    assert 'must be class' in str(exc_info.value)

    with pytest.raises(ValueError) as exc_info:
        pattern(ScalarTimesOperator, wc('a'), wc('b__'), wc('c'))
    assert (
        'Only the first or last argument may have a mode indicating an '
        'occurrence of more than 1' in str(exc_info.value)
    )

    with pytest.raises(ValueError) as exc_info:
        wc('a_____')
    assert "Invalid name_mode" in str(exc_info.value)

    # Both a non-member int and a str lookalike must be rejected as a mode.
    with pytest.raises(ValueError) as exc_info:
        pattern(ScalarTimesOperator, wc('a'), wc('b'), wc_name='S', mode=5)
    assert "Mode must be one of" in str(exc_info.value)

    with pytest.raises(ValueError) as exc_info:
        pattern(ScalarTimesOperator, wc('a'), wc('b'), wc_name='S', mode='1')
    assert "Mode must be one of" in str(exc_info.value)
def test_findall():
    """Check that ``findall`` returns every matching subexpression."""
    h1 = LocalSpace("h1")
    a, b, c = (OperatorSymbol(label, hs=h1) for label in ("a", "b", "c"))
    h1_custom = LocalSpace("h1", local_identifiers={'Create': 'c'})
    c_local = Create(hs=h1_custom)
    expr = 2 * (a * b * c - b * c * a + a * b)

    found_symbols = pattern(OperatorSymbol).findall(expr)
    # 3 + 3 + 2 symbol occurrences across the three products.
    assert len(found_symbols) == 8
    assert set(found_symbols) == {a, b, c}

    any_op = wc(head=Operator)
    found_products = pattern(OperatorTimes, any_op, any_op, any_op).findall(expr)
    assert found_products == [a * b * c, b * c * a]

    local_op = pattern(LocalOperator)
    assert len(local_op.findall(expr)) == 0
    with_local = expr.substitute({c: c_local})
    assert len(local_op.findall(with_local)) == 2
def testMatch(self):
    """Check wildcard matching against SuperOperatorSymbol instances."""
    A = wc("A", head=SuperOperator)
    a = SuperOperatorSymbol("a", hs="hs")
    b = SuperOperatorSymbol("b", hs="hs")
    b2 = SuperOperatorSymbol("b", hs="hs")
    assert b == b2

    for symbol in (a, b):
        assert A.match(symbol)
        assert A.match(symbol)['A'] == symbol

    # The wildcard A occurs twice; both (b, b) and (b, b2) -- with b2 equal
    # to b -- are expected to match, binding A to b.
    for operands in ([b, b], [b, b2]):
        proto = ProtoExpr(args=operands, kwargs={})
        pat = pattern_head(A, A)
        assert pat.match(proto)
        assert pat.match(proto)['A'] == b
def test_wc():
    """Test that the wc() constructor produces the equivalent Pattern instance"""

    def expected(**overrides):
        # Build the reference Pattern from the attributes of an anonymous,
        # unconstrained single-mode wildcard, overriding only what differs.
        fields = dict(
            head=None,
            args=None,
            kwargs=None,
            mode=Pattern.single,
            wc_name=None,
            conditions=None,
        )
        fields.update(overrides)
        return Pattern(**fields)

    cases = [
        (wc(), expected()),
        (wc('a'), expected(wc_name='a')),
        (wc('a_'), wc('a')),
        (wc('a__'), expected(mode=Pattern.one_or_more, wc_name='a')),
        (wc('a___'), expected(mode=Pattern.zero_or_more, wc_name='a')),
        (wc('a', head=int), expected(head=int, wc_name='a')),
        (
            wc('a', head=(int, float)),
            expected(head=(int, float), wc_name='a'),
        ),
    ]
    for actual, reference in cases:
        # Printed so pytest shows which pattern was being compared on failure.
        print(repr(actual))
        assert actual == reference

    # A name consisting only of underscores is rejected.
    with pytest.raises(ValueError):
        wc("____")
assert pat1 == pat2 # test expressions two_t = 2 * Symbol('t') two_O = 2 * OperatorSymbol('O', hs=FullSpace) proto_two_O = ProtoExpr([2, OperatorSymbol('O', hs=FullSpace)], {}) proto_kwargs = ProtoExpr([1, 2], {'a': '3', 'b': 4}) proto_kw_only = ProtoExpr([], {'a': 1, 'b': 2}) proto_ints2 = ProtoExpr([1, 2], {}) proto_ints3 = ProtoExpr([1, 2, 3], {}) proto_ints4 = ProtoExpr([1, 2, 3, 4], {}) proto_ints5 = ProtoExpr([1, 2, 3, 4, 5], {}) # test patterns and wildcards wc_a_int_2 = wc('a', head=(ScalarValue, int), conditions=[lambda i: i == 2]) wc_a_int_3 = wc('a', head=(ScalarValue, int), conditions=[lambda i: i == 3]) wc_a_int = wc('a', head=int) wc_label_str = wc('label', head=str) wc_hs = wc('space', head=HilbertSpace) pattern_two_O = pattern( ScalarTimesOperator, wc_a_int_2, pattern(OperatorSymbol, wc_label_str, hs=wc_hs), ) pattern_two_O_head = pattern_head( wc_a_int_2, pattern(OperatorSymbol, wc_label_str, hs=wc_hs) ) pattern_two_O_expr = pattern( ScalarTimesOperator, wc_a_int_2, OperatorSymbol('O', hs=FullSpace) )