예제 #1
0
def match(target_str, pattern_str):
    'Pre-process target and pattern the same way, then run the matcher'

    def prepare(source):
        'Parse an expression string, pre-process it and flatten its ADD nodes'
        tree = ast.parse(source, mode="eval").body
        tree = pre_processing.all_preprocessings(tree)
        return Flattening(ast.Add).visit(tree)

    # the target is fully prepared before the pattern, as both go through
    # the very same pipeline
    target_ast = prepare(target_str)
    pattern_ast = prepare(pattern_str)
    matcher = PatternMatcher(target_ast)
    return matcher.visit(target_ast, pattern_ast)
예제 #2
0
 def generic_ConstFolding(self, origstring, refstring, nbits, lvl=False):
     '''Generic test for ConstFolding transformer

     origstring: source of the expression that gets constant-folded
     refstring: source of the expected expression after folding
     nbits: value handed to asttools.ConstFolding
     lvl: when true, both trees are flattened before folding/comparison

     Asserts that asttools.Comparator accepts the folded tree against the
     reference tree.
     '''
     orig = ast.parse(origstring)
     ref = ast.parse(refstring)
     if lvl:
         orig = Flattening().visit(orig)
         # flatten the reference itself: building it from orig would make
         # the final comparison ignore refstring altogether
         ref = Flattening().visit(ref)
     orig = asttools.ConstFolding(orig, nbits).visit(orig)
     self.assertTrue(asttools.Comparator().visit(orig, ref))
예제 #3
0
def replace(target_str, pattern_str, replacement_str):
    'Pre-process target and pattern, then replace the pattern in the target'

    def prepare(source):
        'Parse an expression string, pre-process it and flatten its ADD nodes'
        tree = ast.parse(source, mode="eval").body
        tree = pre_processing.all_preprocessings(tree)
        return Flattening(ast.Add).visit(tree)

    target_ast = prepare(target_str)
    patt_ast = prepare(pattern_str)
    # the replacement is parsed with the default mode and is neither
    # pre-processed nor flattened
    rep_ast = ast.parse(replacement_str)
    replacer = PatternReplacement(patt_ast, target_ast, rep_ast)
    return replacer.visit(target_ast)
예제 #4
0
    def test_afterSubMult(self):
        'Tests flattening of ADD nodes after the SubToMult pre-processing'

        def minus(operand):
            'Expected form of a subtracted operand: (-1) * operand'
            return ast.BinOp(ast.Num(-1), ast.Mult(), operand)

        def flat_add(*operands):
            'Expected flattened addition: a BoolOp holding the Add operator'
            return ast.BoolOp(ast.Add(), list(operands))

        cases = [
            ("1 + 2 - 3",
             flat_add(ast.Num(1), ast.Num(2), minus(ast.Num(3)))),
            ("1 + 2 - 3 + 4",
             flat_add(ast.Num(1), ast.Num(2), minus(ast.Num(3)),
                      ast.Num(4))),
            # the subtracted parenthesis is expected to stay a plain BinOp
            ("(1 + 2) - (3 + 4)",
             flat_add(ast.Num(1), ast.Num(2),
                      minus(ast.BinOp(ast.Num(3), ast.Add(), ast.Num(4))))),
        ]
        for source, expected in cases:
            tree = ast.parse(source, mode="eval").body
            tree = pre_processing.all_preprocessings(tree)
            tree = Flattening(ast.Add).visit(tree)
            self.assertTrue(Comparator().visit(tree, expected))
예제 #5
0
 def test_with_funcs(self):
     'Tests flattening of additions inside and around function calls'

     def call(func_name, *arguments):
         'Expected Call node, built with its five positional fields'
         return ast.Call(ast.Name(func_name, ast.Load()), list(arguments),
                         [], None, None)

     def flat_add(*operands):
         'Expected flattened addition: a BoolOp holding the Add operator'
         return ast.BoolOp(ast.Add(), list(operands))

     cases = [
         ("f(1 + 1 + 1)",
          call('f', flat_add(ast.Num(1), ast.Num(1), ast.Num(1)))),
         ("f(1 + 1 + g(2 + 2 + 2))",
          call('f',
               flat_add(ast.Num(1),
                        ast.Num(1),
                        call('g',
                             flat_add(ast.Num(2), ast.Num(2),
                                      ast.Num(2)))))),
         ("f(8) + (a + f(8)) + f(14)",
          flat_add(call('f', ast.Num(8)),
                   ast.Name('a', ast.Load()),
                   call('f', ast.Num(8)),
                   call('f', ast.Num(14)))),
     ]
     for source, expected in cases:
         tree = ast.parse(source, mode="eval").body
         tree = Flattening(ast.Add).visit(tree)
         self.assertTrue(Comparator().visit(tree, expected))
예제 #6
0
    def test_flattened(self):
        '''Test matchings for flattened ast

        Positive part: each input string goes through
        self.generic_test_positive with the pattern and True as last
        argument.
        Negative part: a XOR expression, flattened on every operator, must
        not be matched by the ADD pattern.
        '''
        pattern_string = "A + 2*B + 3*C"
        test_pos = ["x + 2*y + 3*z", "3*z + 2*y + x", "2*y + 3*z + x"]
        for input_string in test_pos:
            self.generic_test_positive(input_string, pattern_string, True)

        # actual pre-processing only flattens ADD nodes, but this test
        # is for code coverage
        test_neg = ast.parse("x ^ 2*y ^ 2*z")
        # test_neg is already an AST: give it to the pre-processing as is
        # instead of pushing it through ast.parse a second time
        test_neg = pre_processing.all_preprocessings(test_neg)
        test_neg = Flattening().visit(test_neg)
        patt_ast = ast.parse(pattern_string)
        patt_ast = pre_processing.all_preprocessings(patt_ast)
        patt_ast = Flattening(ast.Add).visit(patt_ast)
        pat = pattern_matcher.PatternMatcher(test_neg)
        self.assertFalse(pat.visit(test_neg, patt_ast))
예제 #7
0
def test_flattening(expr_string, refgraph):
    'Check the graph built from a flattened expression against a reference'
    parsed = ast.parse(expr_string)
    flattened = Flattening().visit(parsed)
    translator = dag_translator.DAGTranslator(flattened)
    translator.visit(flattened)
    # the graph is compared through its string form
    produced = str(translator.graph.string())
    assert produced == refgraph
예제 #8
0
 def test_astform(self):
     'Tests flattening of the same sum under different root nodes'
     source = "1 + 2 + 3"
     bare_expr = ast.parse(source, mode="eval").body
     flat_sum = ast.BoolOp(ast.Add(), [ast.Num(1), ast.Num(2), ast.Num(3)])
     as_expression = ast.parse(source, mode="eval")
     as_statement = ast.parse(source).body[0]
     # the expected sum is shared by the three references, each one
     # wrapped like the root node of the matching parsed form
     pairs = [
         (bare_expr, flat_sum),
         (as_expression, ast.Expression(flat_sum)),
         (as_statement, ast.Expr(flat_sum)),
     ]
     for parsed, expected in pairs:
         flattened = Flattening().visit(parsed)
         self.assertTrue(Comparator().visit(flattened, expected))
예제 #9
0
    def test_flattened(self):
        'Test replacement on flattened ast, with and without a match'
        patt_string = "A + 2*B + 3*C"
        rep_string = "A"

        # matching case: the whole sum is expected to become "x"
        expected = ast.parse("x", mode='eval').body
        produced = pattern_matcher.replace("3*z + x + 2*y", patt_string,
                                           rep_string)
        self.assertTrue(asttools.Comparator().visit(produced, expected))

        # only ADD nodes are flattened right now, this is for code
        # coverage
        xor_source = "3*z ^ x ^ 2*y"
        target = Flattening().visit(ast.parse(xor_source))
        pattern = Flattening().visit(ast.parse("A + 3*z"))
        replacement = ast.parse(rep_string)
        # the output is compared with a fresh flattened copy of the input
        unchanged = Flattening().visit(ast.parse(xor_source))
        replacer = pattern_matcher.PatternReplacement(pattern, target,
                                                      replacement)
        produced = replacer.visit(target)
        self.assertTrue(asttools.Comparator().visit(produced, unchanged))
예제 #10
0
    def test_unflattening(self):
        'Tests that flattening then unflattening gives the expected ast'

        cases = [
            ("x + (3 + y)", "3 + (y + x)"),
            ("x*(2*z)", "2*(z*x)"),
            ("x + (y + (z*(5*var)))", "y + (5*(var*z) + x)"),
        ]

        for source, expected_source in cases:
            expected = ast.parse(expected_source)
            tree = ast.parse(source)
            # both visitors are run for their effect on the tree: their
            # return values are not used
            Flattening().visit(tree)
            Unflattening().visit(tree)
            same = Comparator().visit(tree, expected)
            self.assertTrue(same)
            # no flattened node may be left in the unparsed output
            leftover = 'BoolOp' in astunparse.unparse(tree)
            self.assertFalse(leftover)
예제 #11
0
    def __init__(self, nbits, rules_list=DEFAULT_RULES):
        'Init context (variables to values) and prepare the rule patterns'
        # pylint: disable=dangerous-default-value
        self.context = {}
        self.nbits = nbits

        def as_expression(source):
            'Parse a rule side as a single expression'
            return ast.parse(source, mode="eval").body

        self.patterns = []
        for patt_source, rep_source in rules_list:
            # only the pattern side is pre-processed and flattened; the
            # replacement side is kept as parsed
            pattern = as_expression(patt_source)
            pattern = all_preprocessings(pattern, self.nbits)
            pattern = Flattening(ast.Add).visit(pattern)
            replacement = as_expression(rep_source)
            self.patterns.append((pattern, replacement))
예제 #12
0
 def simplify(self, expr_ast, nbits):
     '''Apply pattern matching and arithmetic simplification

     One pass made of: arithmetic simplification, GetConstMod on the
     constants, pre-processing, flattening of ADD nodes, application of
     the first pattern of self.patterns that changes the tree, then
     flattening of XOR nodes, constant folding and unflattening.

     Returns the resulting ast.

     NOTE(review): nbits is only given to arithm_simpl.run, every later
     step reads self.nbits -- confirm callers always pass self.nbits.
     '''
     expr_ast = arithm_simpl.run(expr_ast, nbits)
     expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
     # the two debug dumps below print the same tree: nothing changes it
     # between them
     if DEBUG:
         print "arithm simpl: "
         print unparse(expr_ast)
     if DEBUG:
         print "before matching: "
         print unparse(expr_ast)
     expr_ast = all_preprocessings(expr_ast, self.nbits)
     # only flattening ADD nodes because of traditional MBA patterns
     expr_ast = Flattening(ast.Add).visit(expr_ast)
     for pattern, repl in self.patterns:
         rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
         # work on a copy so that expr_ast stays available for the
         # comparison that tells whether the pattern changed anything
         new_ast = rep.visit(deepcopy(expr_ast))
         if not asttools.Comparator().visit(new_ast, expr_ast):
             if DEBUG:
                 print "replaced! "
                 # unflatten copies only: the display must not alter the
                 # trees still in use
                 dispat = deepcopy(pattern)
                 dispat = Unflattening().visit(dispat)
                 print "pattern:  ", unparse(dispat)
                 disnew = deepcopy(new_ast)
                 disnew = Unflattening().visit(disnew)
                 print "after:    ", unparse(disnew)
                 print ""
             expr_ast = new_ast
             # at most one replacement per call: stop at the first
             # pattern that modified the tree
             break
     # bitwise simplification: this is a ugly hack, should be
     # "generalized"
     expr_ast = Flattening(ast.BitXor).visit(expr_ast)
     expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
     # back to a non-flattened tree before returning it
     expr_ast = Unflattening().visit(expr_ast)
     if DEBUG:
         print "after PM: "
         print unparse(expr_ast)
     return expr_ast
예제 #13
0
def main(argv):
    '''Parse options and arguments and translate AST to DOT graph

    argv: list of command line arguments (without the program name), given
    to argparse.

    The positional argument is read as a file when it is the path of an
    existing file, and as a python expression otherwise. Unless --no-cse
    is given, cse.apply_cse is applied first. The flattened ast is then
    translated with DAGTranslator.

    Depending on the options, the graph is written to "<name>.dot" and
    drawn to "<name>.pdf". A short summary is printed and the graph is
    returned.
    '''
    random.seed(time.time())
    parser = argparse.ArgumentParser()
    parser.add_argument("input",
                        type=str,
                        help="python file containing expression to translate" +
                        " OR directly the python expression")
    parser.add_argument("-d",
                        "--draw",
                        action="store_true",
                        help="draw the corresponding graph")
    parser.add_argument("--no-cse", action="store_true", help="deactivate cse")
    parser.add_argument("--no-file",
                        action="store_true",
                        help="deactivate writing in a output file" +
                        " (useful for tests)")

    args = parser.parse_args(argv)

    if os.path.isfile(args.input):
        # the file is closed as soon as its content has been read
        with open(args.input, 'r') as input_file:
            input_ast = ast.parse(input_file.read())
        # drop the extension whatever its length (same result as before
        # for ".py" files)
        filename = os.path.splitext(args.input)[0]
    else:
        # if input is not a file, then it's considered as an expression
        input_ast = ast.parse(args.input)
        filename = "your_output_%d" % random.randint(0, 99)

    if not args.no_cse:
        input_ast = cse.apply_cse(input_ast)[1]

    input_ast = Flattening().visit(input_ast)
    visitor = DAGTranslator(input_ast)
    visitor.visit(input_ast)
    graph = visitor.graph
    graph.subgraph(list(visitor.variables), rank="same")

    if not args.no_file:
        graph.write("%s.dot" % filename)

    if args.draw:
        graph.layout(prog="dot")
        graph.draw("%s.pdf" % filename)

    # single pre-formatted argument: same output whether print is a
    # statement or a function
    print("Number of nodes: %s" % (len(graph),))
    print("Alternation of types: %s" % (visitor.alternation,))
    print("your output is named: %s" % (filename,))
    return graph
예제 #14
0
 def test_withUnaryOp(self):
     'Test flattening when a UnaryOp is involved'
     # the sum under the unary minus is expected to stay a plain BinOp
     negated_sum = ast.UnaryOp(
         ast.USub(), ast.BinOp(ast.Num(6), ast.Add(), ast.Num(2)))
     expected = ast.BoolOp(ast.Add(),
                           [ast.Num(5), negated_sum, ast.Num(3)])
     cases = [("5 + (-(6 + 2)) + 3", expected)]
     for source, wanted in cases:
         tree = ast.parse(source, mode="eval").body
         tree = Flattening(ast.Add).visit(tree)
         self.assertTrue(Comparator().visit(tree, wanted))
예제 #15
0
 def test_differentops(self):
     'Test flattening with operators other than the addition'

     def flat(operator, *values):
         'Expected flattened node: BoolOp of constants under operator'
         return ast.BoolOp(operator, [ast.Num(value) for value in values])

     def pair(left, operator, right):
         'Expected non-flattened node: BinOp between two constants'
         return ast.BinOp(ast.Num(left), operator, ast.Num(right))

     def minus(operand):
         'Expected form of a subtracted operand: (-1) * operand'
         return ast.BinOp(ast.Num(-1), ast.Mult(), operand)

     cases = [
         ("(3 & 5 & 6)",
          flat(ast.BitAnd(), 3, 5, 6)),
         ("(1 ^ 2 ^ 3) - 4",
          ast.BinOp(flat(ast.BitXor(), 1, 2, 3),
                    ast.Add(),
                    minus(ast.Num(4)))),
         ("((1 + 2 + 3) & (4 + 5))",
          ast.BinOp(flat(ast.Add(), 1, 2, 3),
                    ast.BitAnd(),
                    pair(4, ast.Add(), 5))),
         ("(1 & 2 & 3) - (4 & 5)",
          ast.BinOp(flat(ast.BitAnd(), 1, 2, 3),
                    ast.Add(),
                    minus(pair(4, ast.BitAnd(), 5)))),
         ("(1 & 2 & 3) << (4 & 5)",
          ast.BinOp(flat(ast.BitAnd(), 1, 2, 3),
                    ast.LShift(),
                    pair(4, ast.BitAnd(), 5))),
     ]
     for source, expected in cases:
         tree = ast.parse(source, mode="eval").body
         tree = pre_processing.all_preprocessings(tree)
         # no operator given here, unlike the ADD-only tests
         tree = Flattening().visit(tree)
         self.assertTrue(Comparator().visit(tree, expected))
예제 #16
0
 def loop_simplify(self, node):
     'Simplifying loop to reach fixpoint'
     old_value = deepcopy(node.value)
     old_value = Flattening().visit(old_value)
     node.value = self.simplify(node.value, self.nbits)
     copyvalue = deepcopy(node.value)
     copyvalue = Flattening().visit(copyvalue)
     # simplify until fixpoint is reached
     while not asttools.Comparator().visit(old_value, copyvalue):
         old_value = deepcopy(node.value)
         node.value = self.simplify(node.value, self.nbits)
         copyvalue = deepcopy(node.value)
         if len(unparse(copyvalue)) > len(unparse(old_value)):
             node.value = deepcopy(old_value)
             break
         copyvalue = Flattening().visit(copyvalue)
         old_value = Flattening().visit(old_value)
         if asttools.Comparator().visit(old_value, copyvalue):
             old_value = deepcopy(node.value)
             node.value = NotToInv().visit(node.value)
             node.value = self.simplify(node.value, self.nbits)
             copyvalue = deepcopy(node.value)
             # discard if NotToInv increased the size
             if len(unparse(copyvalue)) >= len(unparse(old_value)):
                 node.value = deepcopy(old_value)
                 copyvalue = deepcopy(node.value)
             copyvalue = Flattening().visit(copyvalue)
             old_value = Flattening().visit(old_value)
         if DEBUG:
             print "-" * 80
     # final arithmetic simplification to clean output of matching
     node.value = arithm_simpl.run(node.value, self.nbits)
     asttools.GetConstMod(self.nbits).visit(node.value)
     if DEBUG:
         print "arithm simpl: "
         print unparse(node.value)
         print ""
         print "-" * 80
     return node
예제 #17
0
 def generic_flattening(self, refstring_list, result):
     'Check that every reference string, once flattened, matches result'
     for source in refstring_list:
         parsed = ast.parse(source, mode="eval").body
         flattened = Flattening().visit(parsed)
         matches = Comparator().visit(flattened, result)
         self.assertTrue(matches)