Ejemplo n.º 1
0
def match(target_str, pattern_str):
    'Normalise target and pattern the same way, then run the matcher'

    def _prepare(expr_str):
        'Parse one expression, pre-process it and level its ADD nodes'
        tree = ast.parse(expr_str, mode="eval").body
        tree = pre_processing.all_preprocessings(tree)
        return asttools.LevelOperators(ast.Add).visit(tree)

    # The target is prepared before the pattern, and both go through
    # exactly the same pipeline.
    target_ast = _prepare(target_str)
    pattern_ast = _prepare(pattern_str)
    return PatternMatcher(target_ast).visit(target_ast, pattern_ast)
Ejemplo n.º 2
0
def replace(target_str, pattern_str, replacement_str):
    'Normalise target and pattern, then apply the replacement on target'

    def _normalised(expr_str):
        'Parse one expression, pre-process it and level its ADD nodes'
        tree = ast.parse(expr_str, mode="eval").body
        tree = pre_processing.all_preprocessings(tree)
        return asttools.LevelOperators(ast.Add).visit(tree)

    target_ast = _normalised(target_str)
    patt_ast = _normalised(pattern_str)
    # The replacement is parsed with the default mode (no mode="eval",
    # no .body) and is neither pre-processed nor leveled.
    # NOTE(review): presumably PatternReplacement copes with that root
    # node itself -- confirm against its constructor.
    rep_ast = ast.parse(replacement_str)
    replacer = PatternReplacement(patt_ast, target_ast, rep_ast)
    return replacer.visit(target_ast)
Ejemplo n.º 3
0
    def test_afterSubMult(self):
        'Check flattened ADD nodes once subtractions are pre-processed'

        def minus(operand):
            'Expected shape of a subtracted operand: (-1) * operand'
            return ast.BinOp(ast.Num(-1), ast.Mult(), operand)

        def flat_add(*operands):
            'Expected flattened addition: a BoolOp whose op is ast.Add'
            return ast.BoolOp(ast.Add(), list(operands))

        expectations = [
            ("1 + 2 - 3",
             flat_add(ast.Num(1), ast.Num(2), minus(ast.Num(3)))),
            ("1 + 2 - 3 + 4",
             flat_add(ast.Num(1), ast.Num(2), minus(ast.Num(3)),
                      ast.Num(4))),
            # the subtracted parenthesised sum is expected to stay a
            # plain binary Add under the (-1) multiplication
            ("(1 + 2) - (3 + 4)",
             flat_add(ast.Num(1), ast.Num(2),
                      minus(ast.BinOp(ast.Num(3), ast.Add(), ast.Num(4))))),
        ]
        for source, expected in expectations:
            actual = ast.parse(source, mode="eval").body
            actual = pre_processing.all_preprocessings(actual)
            actual = Flattening(ast.Add).visit(actual)
            self.assertTrue(Comparator().visit(actual, expected))
Ejemplo n.º 4
0
 def test_differentops(self):
     'Level operators other than ADD and compare with hand-built trees'

     def leveled(operator, *values):
         'Expected leveled node: BoolOp of operator over numeric operands'
         return ast.BoolOp(operator(), [ast.Num(value) for value in values])

     def pair(left, operator, right):
         'Expected plain binary node over two numeric operands'
         return ast.BinOp(ast.Num(left), operator(), ast.Num(right))

     def negated(node):
         'Expected shape of a subtracted operand: (-1) * node'
         return ast.BinOp(ast.Num(-1), ast.Mult(), node)

     cases = [
         ("(3 & 5 & 6)", leveled(ast.BitAnd, 3, 5, 6)),
         ("(1 ^ 2 ^ 3) - 4",
          ast.BinOp(leveled(ast.BitXor, 1, 2, 3), ast.Add(),
                    negated(ast.Num(4)))),
         ("((1 + 2 + 3) & (4 + 5))",
          ast.BinOp(leveled(ast.Add, 1, 2, 3), ast.BitAnd(),
                    pair(4, ast.Add, 5))),
         ("(1 & 2 & 3) - (4 & 5)",
          ast.BinOp(leveled(ast.BitAnd, 1, 2, 3), ast.Add(),
                    negated(pair(4, ast.BitAnd, 5)))),
         ("(1 & 2 & 3) << (4 & 5)",
          ast.BinOp(leveled(ast.BitAnd, 1, 2, 3), ast.LShift(),
                    pair(4, ast.BitAnd, 5))),
     ]
     for source, expected in cases:
         leveled_ast = ast.parse(source, mode="eval").body
         leveled_ast = pre_processing.all_preprocessings(leveled_ast)
         # no operator is given here, unlike the ADD-only calls elsewhere
         leveled_ast = asttools.LevelOperators().visit(leveled_ast)
         self.assertTrue(asttools.Comparator().visit(leveled_ast, expected))
Ejemplo n.º 5
0
    def test_leveled(self):
        'Test positive matchings for leveled ast, plus one negative case'
        pattern_string = "A + 2*B + 3*C"
        # the same three terms written in three different orders
        test_pos = ["x + 2*y + 3*z", "3*z + 2*y + x", "2*y + 3*z + x"]
        for input_string in test_pos:
            # NOTE(review): the third argument presumably asks the helper
            # to level the trees -- confirm against generic_test_positive
            self.generic_test_positive(input_string, pattern_string, True)

        # actual pre-processing only levels ADD nodes, but this test is
        # for code coverage
        # The source string is parsed exactly once; the resulting tree is
        # handed to the pre-processing as-is instead of being fed to
        # ast.parse a second time.
        test_neg = ast.parse("x ^ 2*y ^ 2*z")
        test_neg = pre_processing.all_preprocessings(test_neg)
        # leveled without an operator argument, unlike the pattern below
        test_neg = asttools.LevelOperators().visit(test_neg)
        patt_ast = ast.parse(pattern_string)
        patt_ast = pre_processing.all_preprocessings(patt_ast)
        patt_ast = asttools.LevelOperators(ast.Add).visit(patt_ast)
        pat = pattern_matcher.PatternMatcher(test_neg)
        # an XOR expression must not match the ADD pattern
        self.assertFalse(pat.visit(test_neg, patt_ast))
Ejemplo n.º 6
0
    def __init__(self, nbits, rules_list=DEFAULT_RULES):
        'Store nbits and turn every (pattern, replacement) rule into ASTs'
        # pylint: disable=dangerous-default-value
        self.nbits = nbits
        # context: correspondence between variables and values
        self.context = {}

        # Each rule is stored as a (pattern AST, replacement AST) pair.
        self.patterns = []
        for patt_str, repl_str in rules_list:
            pattern_tree = all_preprocessings(
                ast.parse(patt_str, mode="eval").body, self.nbits)
            pattern_tree = Flattening(ast.Add).visit(pattern_tree)
            # the replacement is only parsed: no pre-processing and no
            # flattening is applied to it here
            replacement_tree = ast.parse(repl_str, mode="eval").body
            self.patterns.append((pattern_tree, replacement_tree))
Ejemplo n.º 7
0
 def simplify(self, expr_ast, nbits):
     'Apply pattern matching and arithmetic simplification'
     expr_ast = arithm_simpl.run(expr_ast, nbits)
     expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
     if DEBUG:
         print "arithm simpl: "
         print unparse(expr_ast)
     if DEBUG:
         print "before matching: "
         print unparse(expr_ast)
     expr_ast = all_preprocessings(expr_ast, self.nbits)
     # only flattening ADD nodes because of traditionnal MBA patterns
     expr_ast = Flattening(ast.Add).visit(expr_ast)
     for pattern, repl in self.patterns:
         rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
         new_ast = rep.visit(deepcopy(expr_ast))
         if not asttools.Comparator().visit(new_ast, expr_ast):
             if DEBUG:
                 print "replaced! "
                 dispat = deepcopy(pattern)
                 dispat = Unflattening().visit(dispat)
                 print "pattern:  ", unparse(dispat)
                 disnew = deepcopy(new_ast)
                 disnew = Unflattening().visit(disnew)
                 print "after:    ", unparse(disnew)
                 print ""
             expr_ast = new_ast
             break
     # bitwise simplification: this is a ugly hack, should be
     # "generalized"
     expr_ast = Flattening(ast.BitXor).visit(expr_ast)
     expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
     expr_ast = Unflattening().visit(expr_ast)
     if DEBUG:
         print "after PM: "
         print unparse(expr_ast)
     return expr_ast