def test_different_constraints_on_commutative_operation():
    """Patterns that differ only in their constraint are told apart under a
    commutative head operation ``f_c``.

    ``c1`` accepts values whose string form is longer than one character and
    ``c2`` accepts values whose string form is exactly one character long.
    """
    # The lambda parameter is named ``x`` on purpose: it must correspond to the
    # wildcard ``x_`` used in the patterns.
    c1 = CustomConstraint(lambda x: len(str(x)) > 1)
    c2 = CustomConstraint(lambda x: len(str(x)) == 1)
    pattern1 = Pattern(f_c(x_), c1)
    pattern2 = Pattern(f_c(x_), c2)
    pattern3 = Pattern(f_c(x_, b), c1)
    pattern4 = Pattern(f_c(x_, b), c2)
    matcher = ManyToOneMatcher(pattern1, pattern2, pattern3, pattern4)

    # One short argument: only the unary pattern guarded by c2 applies.
    subject = f_c(a)
    results = list(matcher.match(subject))
    assert len(results) == 1
    assert results[0][0] == pattern2
    assert results[0][1] == {'x': a}

    # Long symbol next to ``b``: only the binary pattern guarded by c1 applies.
    # Collected with ``list`` like the other cases; ``sorted`` added nothing and
    # would need the (pattern, substitution) pairs to be orderable as soon as
    # more than one result came back.
    subject = f_c(Symbol('longer'), b)
    results = list(matcher.match(subject))
    assert len(results) == 1
    assert results[0][0] == pattern3
    assert results[0][1] == {'x': Symbol('longer')}

    # Short symbol next to ``b``: only the binary pattern guarded by c2 applies.
    subject = f_c(a, b)
    results = list(matcher.match(subject))
    assert len(results) == 1
    assert results[0][0] == pattern4
    assert results[0][1] == {'x': a}
def test_different_pattern_different_label():
    """Two distinct patterns registered under distinct labels both match ``a``,
    and each match is reported under its own label."""
    matcher = ManyToOneMatcher()
    matcher.add(Pattern(a), 42)
    matcher.add(Pattern(x_), 23)

    # Normalise every (label, substitution) pair into comparable plain data.
    normalized = []
    for label, substitution in matcher.match(a):
        normalized.append((label, sorted(map(tuple, substitution.items()))))
    normalized.sort()

    assert normalized == [(23, [('x', a)]), (42, [])]
def test_add_duplicate_pattern_with_different_constraint():
    """The same expression with a different constraint is kept as a separate
    pattern by the matcher."""
    unconstrained = Pattern(f(a))
    constrained = Pattern(f(a), MockConstraint(False))

    matcher = ManyToOneMatcher()
    for candidate in (unconstrained, constrained):
        matcher.add(candidate)

    assert len(matcher.patterns) == 2
def test_different_constraints_no_match_on_operation():
    """When every pattern's constraint rejects the binding, nothing matches."""
    # Parameter name ``x`` must stay: it links the constraint to wildcard ``x_``.
    only_a = CustomConstraint(lambda x: x == a)
    only_b = CustomConstraint(lambda x: x == b)
    matcher = ManyToOneMatcher(Pattern(f(x_), only_a), Pattern(f(x_), only_b))

    found = list(matcher.match(f(c)))

    assert len(found) == 0
def test_same_commutative_but_different_pattern():
    """Patterns sharing a commutative sub-pattern but differing in a sibling
    argument each match only their own subject."""
    pattern1 = Pattern(f(f_c(x_), a))
    pattern2 = Pattern(f(f_c(x_), b))
    matcher = ManyToOneMatcher(pattern1, pattern2)

    cases = [
        (f(f_c(a), a), pattern1),
        (f(f_c(a), b), pattern2),
    ]
    for subject, expected_pattern in cases:
        assert list(matcher.match(subject)) == [(expected_pattern, {'x': a})]
def test_variable_expression_match_error():
    """Passing a pattern (which contains a wildcard) as the subject of
    ``DiscriminationNet.match`` raises ``TypeError``."""
    wildcard_pattern = Pattern(f(x_))
    net = DiscriminationNet()
    net.add(wildcard_pattern)

    with pytest.raises(TypeError):
        # Consume the iterator so that the error actually surfaces.
        list(net.match(wildcard_pattern))
def test_code_generation_many_to_one(subject, patterns):
    """Generate source code for a many-to-one matcher and check its matches.

    The source is executed in a fresh module. Its ``match_root`` must yield
    exactly the ``(pattern index, substitution)`` pairs listed in
    ``PARAM_MATCHES`` for ``subject``.
    """
    patterns = [Pattern(p) for p in patterns]
    matcher = ManyToOneMatcher(*patterns)
    generator = CodeGenerator(matcher)
    gc, code = generator.generate_code()
    code = GENERATED_TEMPLATE.format(gc, code)
    compiled = compile(code, '', 'exec')
    module = ModuleType('generated_code')
    # Dump the generated source so it appears in the output of a failing test.
    print('=' * 80)
    print(code)
    print('=' * 80)
    exec(compiled, module.__dict__)
    for pattern in patterns:
        print(pattern)
    matches = list(module.match_root(subject))

    for i, pattern in enumerate(patterns):
        expected_matches = PARAM_MATCHES[subject, pattern.expression]
        for expected_match in expected_matches:
            assert (i, expected_match) in matches, \
                "Subject {!s} and pattern {!s} did not yield the match {!s} but were supposed to".format(
                    subject, pattern, expected_match)
            # Drop every copy so that only unexpected matches remain afterwards.
            while (i, expected_match) in matches:
                matches.remove((i, expected_match))

    # Anything left over was expected by no pattern. Report the leftovers
    # themselves. The loop variable ``pattern`` only refers to the last pattern
    # (or is unbound if there are none), so it is not named here.
    assert matches == [], "Subject {!s} yielded unexpected matches: {!r}".format(subject, matches)
class TestSubstitute:
    """Tests for ``substitute``."""

    @pytest.mark.parametrize(
        ' expression, substitution, expected_result, replaced',
        [
            (a,              {},                        a,              False),
            (a,              {'x': b},                  a,              False),
            (x_,             {'x': b},                  b,              True),
            (x_,             {'x': [a, b]},             [a, b],         True),
            (y_,             {'x': b},                  y_,             False),
            (f(x_),          {'x': b},                  f(b),           True),
            (f(x_),          {'y': b},                  f(x_),          False),
            (f(x_),          {},                        f(x_),          False),
            (f(a, x_),       {'x': b},                  f(a, b),        True),
            (f(x_),          {'x': [a, b]},             f(a, b),        True),
            (f(x_),          {'x': []},                 f(),            True),
            (f(x_, c),       {'x': [a, b]},             f(a, b, c),     True),
            (f(x_, y_),      {'x': a, 'y': b},          f(a, b),        True),
            (f(x_, y_),      {'x': [a, c], 'y': b},     f(a, c, b),     True),
            (f(x_, y_),      {'x': a, 'y': [b, c]},     f(a, b, c),     True),
            (Pattern(f(x_)), {'x': a},                  f(a),           True)
        ]
    )  # yapf: disable
    def test_substitute(self, expression, substitution, expected_result, replaced):
        """The result equals ``expected_result``. A new object comes back exactly
        when something was replaced; otherwise the input object itself comes back."""
        outcome = substitute(expression, substitution)
        assert outcome == expected_result, "Substitution did not yield expected result"
        if not replaced:
            assert outcome is expression, "When nothing is substituted, the original expression has to be returned"
        else:
            assert outcome is not expression, "When substituting, the original expression may not be modified"
def test_randomized_product_net(patterns):
    """Each non-atomic pattern added to a ``DiscriminationNet`` matches a random
    concrete flatterm instance of itself.

    Wildcards in the pattern's flatterm are replaced by random constants from
    ``CONSTANT_EXPRESSIONS``. The net's ``_match`` must then report that
    pattern's index.
    """
    assume(not any(isinstance(p, Atom) for p in patterns))
    patterns = [Pattern(p) for p in patterns]
    net = DiscriminationNet()
    exprs = []
    for pattern in patterns:
        net.add(pattern)
        concrete = []
        for term in FlatTerm(pattern.expression):
            if isinstance(term, Wildcard):
                # Fill the wildcard with exactly ``min_count`` random constants.
                concrete.extend(random.choice(CONSTANT_EXPRESSIONS) for _ in range(term.min_count))
            elif is_symbol_wildcard(term):
                concrete.append(random.choice(CONSTANT_EXPRESSIONS))
            else:
                concrete.append(term)
        # A flatterm that came out empty gets a single random constant instead.
        exprs.append(concrete or [random.choice(CONSTANT_EXPRESSIONS)])

    for index, pattern in enumerate(patterns):
        expr = exprs[index]
        found = net._match(expr)
        assert index in found, "{!s} did not match {!s} in the DiscriminationNet".format(pattern, expr)
def test_grouped():
    """``grouped()`` puts the two identical patterns into one group and the
    wildcard pattern into a separate group."""
    pattern1 = Pattern(a, MockConstraint(True))
    pattern2 = Pattern(a, MockConstraint(True))
    pattern3 = Pattern(x_, MockConstraint(True))
    matcher = ManyToOneMatcher(pattern1, pattern2, pattern3)
    result = [[p for p, _ in ps] for ps in matcher.match(a).grouped()]

    assert len(result) == 2
    # Require exactly one group of each size. Checking each group on its own
    # would also accept e.g. two groups of size 2. This also avoids relying on
    # ``assert False``, which is stripped under ``python -O``.
    assert sorted(len(res) for res in result) == [1, 2], "Wrong number of grouped matches"
    for res in result:
        if len(res) == 2:
            assert pattern1 in res
            assert pattern2 in res
        else:
            assert pattern3 in res
def test_add_duplicate_pattern():
    """Adding the very same pattern twice registers it only once."""
    duplicate = Pattern(f(a))
    matcher = ManyToOneMatcher()
    for _ in range(2):
        matcher.add(duplicate)
    assert len(matcher.patterns) == 1
def test_dataclass_operation_subclass():
    """Matching works on the operation subclass ``MySum`` and yields exactly one
    substitution that binds both wildcards to the subject's symbols."""
    first = Wildcard.dot("x1")
    second = Wildcard.dot("x2")
    subject = MySum(Symbol("foo"), Symbol("bar"))

    [subst] = list(match(subject, Pattern(MySum(first, second))))

    assert subst["x1"].name == "foo"
    assert subst["x2"].name == "bar"
def test_match_anywhere(expression, pattern, expected_results):
    """``match_anywhere`` yields exactly as many results as expected, and every
    expected result is among them."""
    # (The former no-op ``expression = expression`` has been dropped.)
    pattern = Pattern(pattern)
    results = list(match_anywhere(expression, pattern))
    assert len(results) == len(expected_results), "Invalid number of results"
    for result in expected_results:
        assert result in results, "Results differ from expected"
def test_dict_match(match, expression, pattern, expected_matches):
    """The matching function yields every expected match and nothing else.

    ``match`` is the matching function under test. It is presumably supplied
    by a fixture or parametrization that is not visible here.
    """
    pattern = Pattern(pattern)
    result = list(match(expression, pattern))
    for expected_match in expected_matches:
        assert expected_match in result, \
            "Expression {!s} and {!s} did not yield the match {!s} but were supposed to".format(
                expression, pattern, expected_match)
    for result_match in result:
        assert result_match in expected_matches, \
            "Expression {!s} and {!s} yielded the unexpected match {!s}".format(
                expression, pattern, result_match)
def test_same_pattern_different_label():
    """One pattern registered under two labels is reported once per label."""
    pattern = Pattern(a)
    matcher = ManyToOneMatcher()
    for tag in (42, 23):
        matcher.add(pattern, tag)

    collected = sorted(
        (label, sorted(tuple(item) for item in substitution.items()))
        for label, substitution in matcher.match(a)
    )

    assert collected == [(23, []), (42, [])]
def test_generate_net_and_match(pattern, expr, is_match):
    """A net holding a single pattern reports exactly that pattern's index for a
    matching flatterm, and nothing for a non-matching one."""
    net = DiscriminationNet()
    index = net.add(Pattern(pattern))
    found = net._match(expr)

    if not is_match:
        assert found == [], "Matching should fail for {!s} and {!s}".format(pattern, expr)
        return
    assert found == [index], "Matching failed for {!s} and {!s}".format(pattern, expr)
def assert_match_as_expected(match, subject, pattern, expected_matches):
    """Assert that ``match(subject, Pattern(pattern))`` yields exactly ``expected_matches``.

    Args:
        match: the matching function under test; called as ``match(subject, pattern)``.
        subject: the expression to match against.
        pattern: the raw pattern expression; it is wrapped in ``Pattern`` here.
        expected_matches: the matches that must be produced, and no others.

    Raises:
        AssertionError: if the number of matches differs, an expected match is
            missing, or an unexpected match is produced.
    """
    pattern = Pattern(pattern)
    matches = list(match(subject, pattern))
    assert len(matches) == len(expected_matches), 'Unexpected number of matches'
    for expected_match in expected_matches:
        assert expected_match in matches, \
            "Subject {!s} and pattern {!s} did not yield the match {!s} but were supposed to".format(
                subject, pattern, expected_match)
    # Loop variable renamed so that it no longer shadows the ``match`` parameter.
    for actual_match in matches:
        assert actual_match in expected_matches, \
            "Subject {!s} and pattern {!s} yielded the unexpected match {!s}".format(
                subject, pattern, actual_match)
def test_sequence_matcher_match():
    """``SequenceMatcher`` finds every occurrence of its sequence patterns in a
    subject and finds nothing in a bare symbol."""
    patterns = [
        Pattern(f(___, x_, x_, ___)),
        Pattern(f(z___, a, b, ___)),
        Pattern(f(___, a, c, z___)),
        Pattern(f(z___, a, c, z___)),
    ]
    matcher = SequenceMatcher(*patterns)
    found = list(matcher.match(f(a, b, c, a, a, b, a, c, b)))

    expected = [
        (patterns[0], {'x': a}),
        (patterns[1], {'z': ()}),
        (patterns[1], {'z': (a, b, c, a)}),
        (patterns[2], {'z': (b, )}),
    ]
    assert len(found) == len(expected)
    for item in expected:
        assert item in found

    assert list(matcher.match(a)) == []
def test_many_to_one(subject, patterns):
    """``ManyToOneMatcher`` yields exactly the ``(pattern, substitution)`` pairs
    listed in ``PARAM_MATCHES`` for ``subject``."""
    patterns = [Pattern(p) for p in patterns]
    matcher = ManyToOneMatcher(*patterns)
    matches = list(matcher.match(subject))

    for pattern in patterns:
        expected_matches = PARAM_MATCHES[subject, pattern.expression]
        for expected_match in expected_matches:
            assert (pattern, expected_match) in matches, \
                "Subject {!s} and pattern {!s} did not yield the match {!s} but were supposed to".format(
                    subject, pattern, expected_match)
            # Drop every copy so that only unexpected matches remain afterwards.
            while (pattern, expected_match) in matches:
                matches.remove((pattern, expected_match))

    # Anything left over was expected by no pattern. Report the leftovers
    # themselves. The loop variable ``pattern`` only refers to the last pattern
    # (or is unbound if there are none), so it is not named here.
    assert matches == [], "Subject {!s} yielded unexpected matches: {!r}".format(subject, matches)
class TestSubstitute:
    """Tests for ``substitute``."""

    @pytest.mark.parametrize(
        ' expression, substitution, expected_result, replaced',
        [
            (a,              {},                        a,              False),
            (a,              {'x': b},                  a,              False),
            (x_,             {'x': b},                  b,              True),
            (x_,             {'x': [a, b]},             [a, b],         True),
            (y_,             {'x': b},                  y_,             False),
            (f(x_),          {'x': b},                  f(b),           True),
            (f(x_),          {'y': b},                  f(x_),          False),
            (f(x_),          {},                        f(x_),          False),
            (f(a, x_),       {'x': b},                  f(a, b),        True),
            (f(x_),          {'x': [a, b]},             f(a, b),        True),
            (f(x_),          {'x': []},                 f(),            True),
            (f(x_, c),       {'x': [a, b]},             f(a, b, c),     True),
            (f(x_, y_),      {'x': a, 'y': b},          f(a, b),        True),
            (f(x_, y_),      {'x': [a, c], 'y': b},     f(a, c, b),     True),
            (f(x_, y_),      {'x': a, 'y': [b, c]},     f(a, b, c),     True),
            (Pattern(f(x_)), {'x': a},                  f(a),           True)
        ]
    )  # yapf: disable
    def test_substitute(self, expression, substitution, expected_result, replaced):
        """The result equals ``expected_result``. A new object comes back exactly
        when something was replaced; otherwise the input object itself comes back."""
        result = substitute(expression, substitution)
        assert result == expected_result, "Substitution did not yield expected result"
        if replaced:
            assert result is not expression, "When substituting, the original expression may not be modified"
        else:
            assert result is expression, "When nothing is substituted, the original expression has to be returned"

    def test_substitute_custom_sorting_key(self):
        """A sort key passed to ``substitute`` controls the order in which the
        elements of a ``Multiset`` replacement are spliced in. Without one, the
        order is alphabetical."""
        def sort_key(item):
            # Reverse alphabetical order. ``ord`` only accepts one-character
            # strings, so this key works for single-letter symbols only.
            return -ord(str(item))

        expression = f(x_, y_)
        substitution = {'x': a, 'y': Multiset([b, c])}

        result = substitute(expression, substitution, sort_key)
        assert result == f(a, c, b)
        assert result != f(a, b, c)

        # Without the custom key, sorting is alphabetical again.
        result = substitute(expression, substitution)
        assert result != f(a, c, b)
        assert result == f(a, b, c)
def test_different_pattern_same_constraint(c1, c2):
    """Constraints shared between several patterns are evaluated correctly.

    Only ``f(a, x_)`` (guarded by ``c1``) and ``f(x_, a)`` (guarded by ``c2``)
    structurally fit ``f(a, a)``. The number of results therefore equals how
    many of the two flags are true.
    """
    # The lambdas keep parameter ``x`` so that they refer to wildcard ``x_``.
    guard_c1 = CustomConstraint(lambda x: c1)
    guard_c2 = CustomConstraint(lambda x: c2)
    always = CustomConstraint(lambda x: True)
    candidates = [
        Pattern(f2(x_, a), always),
        Pattern(f(a, a, x_), always),
        Pattern(f(a, x_), guard_c1),
        Pattern(f(x_, a), guard_c2),
        Pattern(f(a, x_, b), guard_c1),
        Pattern(f(x_, a, b), guard_c1),
    ]
    subject = f(a, a)

    found = list(ManyToOneMatcher(*candidates).match(subject))

    assert len(found) == int(c1) + int(c2)
flatterm.append(random.choice(CONSTANT_EXPRESSIONS)) else: flatterm.append(term) if not flatterm: flatterm = [random.choice(CONSTANT_EXPRESSIONS)] exprs.append(flatterm) for index, (pattern, expr) in enumerate(zip(patterns, exprs)): result = net._match(expr) assert index in result, "{!s} did not match {!s} in the DiscriminationNet".format( pattern, expr) PRODUCT_NET_PATTERNS = [ Pattern(f(a, _, _)), Pattern(f(_, a, _)), Pattern(f(_, _, a)), Pattern(f(__)), Pattern(f(f2(_, ___))), Pattern(f(___, f2(_))), Pattern(_), ] PRODUCT_NET_EXPRESSIONS = [ f(a, a, a), f(b, a, a), f(a, b, a), f(a, a, b), f(f2(a), a, a), f(f2(a), a, f2(b)),
def test_one_identity_optional_commutativity():
    """Optional wildcards combined with commutative, one-identity operations
    can yield several substitutions for a single pattern.

    ``pattern22`` has an optional ``a`` (default ``0``) and matches the subject
    in two ways. ``pattern23`` requires a real ``a`` and matches in only one way.
    """
    Int = Operation.new('Int', Arity.binary)
    Add = Operation.new('+', Arity.variadic, 'Add', infix=True, associative=True, commutative=True, one_identity=True)
    Mul = Operation.new('*', Arity.variadic, 'Mul', infix=True, associative=True, commutative=True, one_identity=True)
    Pow = Operation.new('^', Arity.binary, 'Pow', infix=True)

    class Integer(Symbol):
        # A Symbol whose name is the string form of ``value``.
        def __init__(self, value):
            super().__init__(str(value))

    i0 = Integer(0)
    i1 = Integer(1)
    i2 = Integer(2)

    # The previously unused ``m_`` wildcard and ``m`` symbol have been dropped.
    x_, a_ = map(Wildcard.dot, 'xa')
    x = Symbol('x')
    # Optional wildcards fall back to 0 (additive) or 1 (multiplicative / exponent).
    a0_ = Wildcard.optional('a', i0)
    b1_ = Wildcard.optional('b', i1)
    c0_ = Wildcard.optional('c', i0)
    d1_ = Wildcard.optional('d', i1)
    m1_ = Wildcard.optional('m', i1)
    n1_ = Wildcard.optional('n', i1)

    pattern22 = Pattern(
        Int(
            Mul(Pow(Add(a0_, Mul(b1_, x_)), m1_), Pow(Add(c0_, Mul(d1_, x_)), n1_)),
            x_))
    pattern23 = Pattern(
        Int(
            Mul(Pow(Add(a_, Mul(b1_, x_)), m1_), Pow(Add(c0_, Mul(d1_, x_)), n1_)),
            x_))
    matcher = ManyToOneMatcher()
    matcher.add(pattern22, 22)
    matcher.add(pattern23, 23)

    subject = Int(Mul(Pow(Add(Mul(b, x), a), i2), Pow(x, i2)), x)
    result = sorted(
        (l, sorted(map(tuple, s.items()))) for l, s in matcher.match(subject))

    assert result == [
        (22, [('a', i0), ('b', i1), ('c', a), ('d', b), ('m', i2), ('n', i2), ('x', x)]),
        (22, [('a', a), ('b', b), ('c', i0), ('d', i1), ('m', i2), ('n', i2), ('x', x)]),
        (23, [('a', a), ('b', b), ('c', i0), ('d', i1), ('m', i2), ('n', i2), ('x', x)]),
    ]
def test_is_match(expr, pattern, do_match):
    """``is_match`` agrees with the expected boolean for each parametrized case."""
    verdict = is_match(expr, Pattern(pattern))
    assert verdict == do_match
def test_logic_simplify(replacer):
    """Rewrite a large propositional formula with a rule set and check the outcome.

    The rules rewrite the connectives into and/xor form with the constants
    ⊥ and ⊤. ``replacer`` is called as ``replacer(expression, rules)``.
    Presumably it is a replace-all style function supplied by a fixture or
    parametrization; that is not visible here.

    NOTE(review): the formula is built as implies(and(P, implies(P, Q)), Q)
    (see the repeated sub-terms below), which is a tautology, yet the expected
    result is ⊥. Also, the "xor(x, x) → ⊥" rule discards the remaining ``___``
    arguments instead of keeping them. Confirm whether the expected value
    depends on that rule behaviour.
    """
    LAnd = Operation.new('and', Arity.variadic, 'LAnd', associative=True, one_identity=True, commutative=True)
    LOr = Operation.new('or', Arity.variadic, 'LOr', associative=True, one_identity=True, commutative=True)
    LXor = Operation.new('xor', Arity.variadic, 'LXor', associative=True, one_identity=True, commutative=True)
    LNot = Operation.new('not', Arity.unary, 'LNot')
    LImplies = Operation.new('implies', Arity.binary, 'LImplies')
    Iff = Operation.new('iff', Arity.binary, 'Iff')
    # Local star wildcard used by the "xor(x, x)" rule below.
    ___ = Wildcard.star()

    a1 = Symbol('a1')
    a2 = Symbol('a2')
    a3 = Symbol('a3')
    a4 = Symbol('a4')
    a5 = Symbol('a5')
    a6 = Symbol('a6')
    a7 = Symbol('a7')
    a8 = Symbol('a8')
    a9 = Symbol('a9')
    a10 = Symbol('a10')
    a11 = Symbol('a11')

    # Constants: LBot is the identity of xor, LTop the identity of and (see rules).
    LBot = Symbol(u'⊥')
    LTop = Symbol(u'⊤')

    expression = LImplies(
        LAnd(
            Iff(
                Iff(LOr(a1, a2), LOr(LNot(a3), Iff(LXor(a4, a5), LNot(LNot(LNot(a6)))))),
                LNot(
                    LAnd(
                        LAnd(a7, a8),
                        LNot(
                            LXor(
                                LXor(LOr(a9, LAnd(a10, a11)), a2),
                                LAnd(LAnd(a11, LXor(a2, Iff(a5, a5))), LXor(LXor(a7, a7), Iff(a9, a4)))
                            )
                        )
                    )
                )
            ),
            LImplies(
                Iff(
                    Iff(LOr(a1, a2), LOr(LNot(a3), Iff(LXor(a4, a5), LNot(LNot(LNot(a6)))))),
                    LNot(
                        LAnd(
                            LAnd(a7, a8),
                            LNot(
                                LXor(
                                    LXor(LOr(a9, LAnd(a10, a11)), a2),
                                    LAnd(LAnd(a11, LXor(a2, Iff(a5, a5))), LXor(LXor(a7, a7), Iff(a9, a4)))
                                )
                            )
                        )
                    )
                ),
                LNot(
                    LAnd(
                        LImplies(
                            LAnd(a1, a2),
                            LNot(
                                LXor(
                                    LOr(
                                        LOr(
                                            LXor(LImplies(LAnd(a3, a4), LImplies(a5, a6)), LOr(a7, a8)),
                                            LXor(Iff(a9, a10), a11)
                                        ),
                                        LXor(LXor(a2, a2), a7)
                                    ),
                                    Iff(LOr(a4, a9), LXor(LNot(a6), a6))
                                )
                            )
                        ),
                        LNot(Iff(LNot(a11), LNot(a9)))
                    )
                )
            )
        ),
        LNot(
            LAnd(
                LImplies(
                    LAnd(a1, a2),
                    LNot(
                        LXor(
                            LOr(
                                LOr(
                                    LXor(LImplies(LAnd(a3, a4), LImplies(a5, a6)), LOr(a7, a8)),
                                    LXor(Iff(a9, a10), a11)
                                ),
                                LXor(LXor(a2, a2), a7)
                            ),
                            Iff(LOr(a4, a9), LXor(LNot(a6), a6))
                        )
                    )
                ),
                LNot(Iff(LNot(a11), LNot(a9)))
            )
        )
    )

    # Lambda parameter names must match the wildcard names used in each pattern.
    rules = [
        # xor(x,⊥) → x
        ReplacementRule(
            Pattern(LXor(x__, LBot)),
            lambda x: LXor(*x)
        ),
        # xor(x, x) → ⊥
        ReplacementRule(
            Pattern(LXor(x_, x_, ___)),
            lambda x: LBot
        ),
        # and(x,⊤) → x
        ReplacementRule(
            Pattern(LAnd(x__, LTop)),
            lambda x: LAnd(*x)
        ),
        # and(x,⊥) → ⊥
        ReplacementRule(
            Pattern(LAnd(__, LBot)),
            lambda: LBot
        ),
        # and(x, x) → x
        ReplacementRule(
            Pattern(LAnd(x_, x_, y___)),
            lambda x, y: LAnd(x, *y)
        ),
        # and(x, xor(y, z)) → xor(and(x, y), and(x, z))
        ReplacementRule(
            Pattern(LAnd(x_, LXor(y_, z_))),
            lambda x, y, z: LXor(LAnd(x, y), LAnd(x, z))
        ),
        # implies(x, y) → not(xor(x, and(x, y)))
        ReplacementRule(
            Pattern(LImplies(x_, y_)),
            lambda x, y: LNot(LXor(x, LAnd(x, y)))
        ),
        # not(x) → xor(x,⊤)
        ReplacementRule(
            Pattern(LNot(x_)),
            lambda x: LXor(x, LTop)
        ),
        # or(x, y) → xor(and(x, y), xor(x, y))
        ReplacementRule(
            Pattern(LOr(x_, y_)),
            lambda x, y: LXor(LAnd(x, y), LXor(x, y))
        ),
        # iff(x, y) → not(xor(x, y))
        ReplacementRule(
            Pattern(Iff(x_, y_)),
            lambda x, y: LNot(LXor(x, y))
        ),
    ]  # yapf: disable

    result = replacer(expression, rules)
    assert result == LBot
def test_sequence_matcher_can_match(pattern, can_match):
    """``SequenceMatcher.can_match`` classifies each parametrized pattern as expected."""
    verdict = SequenceMatcher.can_match(Pattern(pattern))
    assert verdict == can_match