Beispiel #1
0
    def test_onBoolOp(self):
        'Tests on BoolOp'

        # reordered operands of a flat addition are expected to compare equal
        expr_a = ast.BoolOp(ast.Add(), [ast.Num(1), ast.Num(2), ast.Num(3)])
        expr_b = ast.BoolOp(ast.Add(), [ast.Num(3), ast.Num(2), ast.Num(1)])
        self.assertTrue(asttools.Comparator().visit(expr_a, expr_b))

        # nested BoolOp, reordered at both levels; the operator must be an
        # instance (ast.Add()), as in any parsed AST, not the class itself
        expr_a = ast.BoolOp(ast.Add(), [ast.Num(1),
                                        ast.BoolOp(ast.Mult(),
                                                   [ast.Num(5), ast.Num(6)]),
                                        ast.Num(4)])
        expr_b = ast.BoolOp(ast.Add(), [ast.BoolOp(ast.Mult(),
                                                   [ast.Num(6), ast.Num(5)]),
                                        ast.Num(4),
                                        ast.Num(1)])
        self.assertTrue(asttools.Comparator().visit(expr_a, expr_b))
Beispiel #2
0
    def test_comp(self):
        'Basic tests for comparison'
        comparator = asttools.Comparator()

        def check(left, right, expected):
            'Assert that the comparator verdict for (left, right) is expected'
            verdict = comparator.visit(left, right)
            if expected:
                self.assertTrue(verdict)
            else:
                self.assertFalse(verdict)

        # addition commutes, subtraction does not
        sum_xy = ast.parse('x + y')
        check(sum_xy, ast.parse('x + y'), True)
        check(sum_xy, ast.parse('y + x'), True)
        diff_xy = ast.parse('x - y')
        check(sum_xy, diff_xy, False)
        check(diff_xy, ast.parse('y - x'), False)

        # larger expression: commuted operands match, another constant does not
        mixed = ast.parse('2*(x & y) + ((a - 3) ^ 45)')
        mixed_copy = ast.parse('2*(x & y) + ((a - 3) ^ 45)')
        check(mixed, mixed_copy, True)
        check(mixed, ast.parse('2*(y & x) + (45 ^ (a - 3))'), True)
        check(mixed_copy, ast.parse('3*(y & x) + (45 ^ (a - 3))'), False)

        # redundant parentheses are irrelevant
        check(ast.parse('(3*x) + 57 - (x | (-2))'),
              ast.parse('3*x + 57 - (x | (-2))'), True)
Beispiel #3
0
    def test_afterSubMult(self):
        'Tests after SubToMult pre-processing'

        def minus(operand):
            'Build (-1)*operand, the form a subtracted term is expected in'
            return ast.BinOp(ast.Num(-1), ast.Mult(), operand)

        def flat_add(*operands):
            'Build a leveled addition over the given operands'
            return ast.BoolOp(ast.Add(), list(operands))

        tests = [
            ("1 + 2 - 3",
             flat_add(ast.Num(1), ast.Num(2), minus(ast.Num(3)))),
            ("1 + 2 - 3 + 4",
             flat_add(ast.Num(1), ast.Num(2), minus(ast.Num(3)), ast.Num(4))),
            ("(1 + 2) - (3 + 4)",
             flat_add(ast.Num(1), ast.Num(2),
                      minus(ast.BinOp(ast.Num(3), ast.Add(), ast.Num(4))))),
        ]
        for source, expected in tests:
            tree = ast.parse(source, mode="eval").body
            tree = pre_processing.all_preprocessings(tree)
            tree = asttools.LevelOperators(ast.Add).visit(tree)
            self.assertTrue(asttools.Comparator().visit(tree, expected))
Beispiel #4
0
 def test_differentops(self):
     'Test with other types of operators'

     def nums(*values):
         'Wrap plain integers into a list of Num nodes'
         return [ast.Num(value) for value in values]

     def minus(operand):
         'Build (-1)*operand, the form a subtracted term is expected in'
         return ast.BinOp(ast.Num(-1), ast.Mult(), operand)

     tests = [
         ("(3 & 5 & 6)",
          ast.BoolOp(ast.BitAnd(), nums(3, 5, 6))),
         ("(1 ^ 2 ^ 3) - 4",
          ast.BinOp(ast.BoolOp(ast.BitXor(), nums(1, 2, 3)),
                    ast.Add(),
                    minus(ast.Num(4)))),
         ("((1 + 2 + 3) & (4 + 5))",
          ast.BinOp(ast.BoolOp(ast.Add(), nums(1, 2, 3)),
                    ast.BitAnd(),
                    ast.BinOp(ast.Num(4), ast.Add(), ast.Num(5)))),
         ("(1 & 2 & 3) - (4 & 5)",
          ast.BinOp(ast.BoolOp(ast.BitAnd(), nums(1, 2, 3)),
                    ast.Add(),
                    minus(ast.BinOp(ast.Num(4), ast.BitAnd(), ast.Num(5))))),
         ("(1 & 2 & 3) << (4 & 5)",
          ast.BinOp(ast.BoolOp(ast.BitAnd(), nums(1, 2, 3)),
                    ast.LShift(),
                    ast.BinOp(ast.Num(4), ast.BitAnd(), ast.Num(5)))),
     ]
     for source, expected in tests:
         tree = ast.parse(source, mode="eval").body
         tree = pre_processing.all_preprocessings(tree)
         tree = asttools.LevelOperators().visit(tree)
         self.assertTrue(asttools.Comparator().visit(tree, expected))
Beispiel #5
0
 def generic_basicCSE(self, instring, refstring):
     'Generic test for CSE: matching of CSE AST and ref AST'
     # compare parsed trees rather than raw strings, so only the structure
     # of the CSE output matters, not its textual layout
     cse_tree = ast.parse(cse.apply_cse(instring))
     expected_tree = ast.parse(refstring)
     same = asttools.Comparator().visit(expected_tree, cse_tree)
     self.assertTrue(same)
Beispiel #6
0
 def test_with_funcs(self):
     'Tests with functions'

     def call(name, argument):
         'Build a call of function `name` with a single positional argument'
         # five-field Call layout: func, args, keywords, starargs, kwargs
         return ast.Call(ast.Name(name, ast.Load()), [argument], [],
                         None, None)

     def flat_add(*operands):
         'Build a leveled addition over the given operands'
         return ast.BoolOp(ast.Add(), list(operands))

     tests = [
         ("f(1 + 1 + 1)",
          call('f', flat_add(*[ast.Num(n) for n in [1, 1, 1]]))),
         ("f(1 + 1 + g(2 + 2 + 2))",
          call('f', flat_add(ast.Num(1), ast.Num(1),
                             call('g', flat_add(ast.Num(2), ast.Num(2),
                                                ast.Num(2)))))),
     ]
     for source, expected in tests:
         tree = ast.parse(source, mode="eval").body
         tree = asttools.LevelOperators(ast.Add).visit(tree)
         self.assertTrue(asttools.Comparator().visit(tree, expected))
Beispiel #7
0
 def test_root(self):
     'Test with different types of roots'
     # the pattern covers the whole input, so the replacement becomes
     # the new root of the tree
     pattern = ast.parse("A + B", mode='eval')
     target = ast.parse("x + y", mode='eval')
     expected = ast.parse("89", mode='eval')
     replacement = ast.parse("89", mode='eval')
     replacer = pattern_matcher.PatternReplacement(pattern, target,
                                                   replacement)
     result = replacer.visit(target)
     self.assertTrue(asttools.Comparator().visit(result, expected))
Beispiel #8
0
 def generic_ConstFolding(self, origstring, refstring, nbits, lvl=False):
     'Generic test for ConstFolding transformer'
     # origstring: source to fold; refstring: expected result source;
     # nbits: bit width given to ConstFolding; lvl: flatten both trees
     # before folding/comparing
     orig = ast.parse(origstring)
     ref = ast.parse(refstring)
     if lvl:
         orig = Flattening().visit(orig)
         # flatten the reference itself (previously orig was flattened a
         # second time and the parsed reference was thrown away)
         ref = Flattening().visit(ref)
     orig = asttools.ConstFolding(orig, nbits).visit(orig)
     self.assertTrue(asttools.Comparator().visit(orig, ref))
Beispiel #9
0
 def test_associativity(self):
     'Simple tests for associativity'
     pattern, replacement = "3*A + 2*B", "B"
     cases = (("2*x + 3*y", "x"),
              ("2*x + y + 3*g", "x + y"))
     for source, expected_source in cases:
         expected = ast.parse(expected_source, mode="eval").body
         result = pattern_matcher.replace(source, pattern, replacement)
         self.assertTrue(asttools.Comparator().visit(result, expected))
Beispiel #10
0
 def generic_test_replacement(self, tests, pattern, replacement):
     'Generic test for a list of input/output'
     for source, expected_source in tests:
         # every tree, pattern and replacement included, is parsed anew for
         # each case so that no AST node is shared between cases
         tree = ast.parse(source)
         expected = ast.parse(expected_source)
         pattern_tree = ast.parse(pattern)
         replacement_tree = ast.parse(replacement)
         replacer = pattern_matcher.PatternReplacement(pattern_tree, tree,
                                                       replacement_tree)
         result = replacer.visit(tree)
         self.assertTrue(asttools.Comparator().visit(result, expected))
Beispiel #11
0
 def test_astform(self):
     'Tests with different types of ast'
     source = "1 + 2 + 3"
     # same expression as a bare expression node, as an Expression
     # (mode="eval") and as an Expr statement taken from a module body
     bare = ast.parse(source, mode="eval").body
     flat_ref = ast.BoolOp(ast.Add(), [ast.Num(1), ast.Num(2), ast.Num(3)])
     expression = ast.parse(source, mode="eval")
     statement = ast.parse(source).body[0]
     cases = [(bare, flat_ref),
              (expression, ast.Expression(flat_ref)),
              (statement, ast.Expr(flat_ref))]
     for tree, expected in cases:
         leveled = asttools.LevelOperators().visit(tree)
         self.assertTrue(asttools.Comparator().visit(leveled, expected))
Beispiel #12
0
    def test_leveled(self):
        'Test on leveled ast'
        patt_string = "A + 2*B + 3*C"
        rep_string = "A"

        # positive case: the pattern matches a leveled addition whose
        # terms appear in a different order
        expected = ast.parse("x", mode='eval').body
        result = pattern_matcher.replace("3*z + x + 2*y", patt_string,
                                         rep_string)
        self.assertTrue(asttools.Comparator().visit(result, expected))

        # negative case, for code coverage: only ADD nodes are leveled
        # right now, so the addition pattern must leave this XOR
        # expression unchanged
        leveler = asttools.LevelOperators
        xor_tree = leveler().visit(ast.parse("3*z ^ x ^ 2*y"))
        pattern_tree = leveler().visit(ast.parse("A + 3*z"))
        replacement_tree = ast.parse(rep_string)
        expected = leveler().visit(ast.parse("3*z ^ x ^ 2*y"))
        replacer = pattern_matcher.PatternReplacement(pattern_tree, xor_tree,
                                                      replacement_tree)
        result = replacer.visit(xor_tree)
        self.assertTrue(asttools.Comparator().visit(result, expected))
Beispiel #13
0
 def test_withUnaryOp(self):
     'Test with UnaryOp involved'
     # the negated sum is an opaque operand: leveling the outer addition
     # must not merge the inner (6 + 2) into it
     negated_sum = ast.UnaryOp(ast.USub(),
                               ast.BinOp(ast.Num(6), ast.Add(), ast.Num(2)))
     cases = [("5 + (-(6 + 2)) + 3",
               ast.BoolOp(ast.Add(), [ast.Num(5), negated_sum, ast.Num(3)]))]
     for source, expected in cases:
         tree = ast.parse(source, mode="eval").body
         tree = asttools.LevelOperators(ast.Add).visit(tree)
         self.assertTrue(asttools.Comparator().visit(tree, expected))
Beispiel #14
0
 def loop_simplify(self, node):
     'Simplifying loop to reach fixpoint'
     old_value = deepcopy(node.value)
     old_value = Flattening().visit(old_value)
     node.value = self.simplify(node.value, self.nbits)
     copyvalue = deepcopy(node.value)
     copyvalue = Flattening().visit(copyvalue)
     # simplify until fixpoint is reached
     while not asttools.Comparator().visit(old_value, copyvalue):
         old_value = deepcopy(node.value)
         node.value = self.simplify(node.value, self.nbits)
         copyvalue = deepcopy(node.value)
         if len(unparse(copyvalue)) > len(unparse(old_value)):
             node.value = deepcopy(old_value)
             break
         copyvalue = Flattening().visit(copyvalue)
         old_value = Flattening().visit(old_value)
         if asttools.Comparator().visit(old_value, copyvalue):
             old_value = deepcopy(node.value)
             node.value = NotToInv().visit(node.value)
             node.value = self.simplify(node.value, self.nbits)
             copyvalue = deepcopy(node.value)
             # discard if NotToInv increased the size
             if len(unparse(copyvalue)) >= len(unparse(old_value)):
                 node.value = deepcopy(old_value)
                 copyvalue = deepcopy(node.value)
             copyvalue = Flattening().visit(copyvalue)
             old_value = Flattening().visit(old_value)
         if DEBUG:
             print "-" * 80
     # final arithmetic simplification to clean output of matching
     node.value = arithm_simpl.run(node.value, self.nbits)
     asttools.GetConstMod(self.nbits).visit(node.value)
     if DEBUG:
         print "arithm simpl: "
         print unparse(node.value)
         print ""
         print "-" * 80
     return node
Beispiel #15
0
    def test_unleveling(self):
        'Tests to see if unleveling is correct'

        tests = [("x + (3 + y)", "3 + (y + x)"),
                 ("x*(2*z)", "2*(z*x)"),
                 ("x + (y + (z*(5*var)))", "y + (5*(var*z) + x)")]

        for test, ref in tests:
            ref_ast = ast.parse(ref)
            ast_test = ast.parse(test)
            # keep the visitors' return values in case the root is replaced
            ast_test = asttools.LevelOperators().visit(ast_test)
            ast_test = asttools.Unleveling().visit(ast_test)
            self.assertTrue(asttools.Comparator().visit(ast_test, ref_ast))
            # no leveled node may survive unleveling; the node class name
            # only shows up in a dump of the tree, never in unparsed source
            self.assertNotIn('BoolOp', ast.dump(ast_test))
 def check_wildcard(self, target, pattern):
     'Check wildcard value or affect it'
     name = pattern.id
     # first occurrence of this wildcard: bind it to the target node
     if name not in self.wildcards:
         self.wildcards[name] = target
         return True
     # already bound: target must equal the bound value structurally,
     # or, when FLEXIBLE is set, according to check_eq_z3
     bound = self.wildcards[name]
     if asttools.Comparator().visit(bound, target):
         return True
     if not FLEXIBLE:
         return False
     return self.check_eq_z3(target, bound)
Beispiel #17
0
 def visit_Expr(self, node):
     'Simplify expression and replace it'
     # Apply self.simplify until two consecutive results compare equal.
     # Comparison is done on leveled deep copies; node.value itself is
     # never leveled here.
     previous = asttools.LevelOperators().visit(deepcopy(node.value))
     while True:
         node.value = self.simplify(node.value, self.nbits)
         current = asttools.LevelOperators().visit(deepcopy(node.value))
         if asttools.Comparator().visit(previous, current):
             return node
         previous = asttools.LevelOperators().visit(deepcopy(node.value))
Beispiel #18
0
 def simplify(self, expr_ast, nbits):
     'Apply pattern matching and arithmetic simplification'
     if DEBUG:
         print "before: "
         print unparse(expr_ast)
         print ""
     expr_ast = all_target_preprocessings(expr_ast, self.nbits)
     expr_ast = asttools.LevelOperators(ast.Add).visit(expr_ast)
     for pattern, repl in self.patterns:
         rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
         new_ast = rep.visit(deepcopy(expr_ast))
         if DEBUG:
             if not asttools.Comparator().visit(new_ast, expr_ast):
                 print "replaced! "
                 expr_debug = deepcopy(expr_ast)
                 expr_debug = asttools.Unleveling().visit(expr_debug)
                 print unparse(expr_debug)
                 new_debug = deepcopy(new_ast)
                 new_debug = asttools.Unleveling().visit(new_debug)
                 print unparse(new_debug)
                 print "before:   ", ast.dump(expr_ast)
                 print "pattern:  ", ast.dump(pattern)
                 patt_debug = asttools.Unleveling().visit(deepcopy(pattern))
                 print unparse(patt_debug)
                 print ""
                 print ""
                 print "after:    ", ast.dump(new_ast)
                 print ""
         expr_ast = new_ast
     # bitwise simplification: this is a ugly hack, should be
     # "generalized"
     expr_ast = asttools.LevelOperators(ast.BitXor).visit(expr_ast)
     expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
     expr_ast = asttools.Unleveling().visit(expr_ast)
     if DEBUG:
         print "after PM: "
         print unparse(expr_ast)
         print ""
     expr_ast = arithm_simpl.run(expr_ast, nbits)
     expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
     if DEBUG:
         print "arithm simpl: "
         print unparse(expr_ast)
         print ""
         print "-" * 80
     return expr_ast
Beispiel #19
0
    def generic_AstCompTest(self, *args):
        """Args: (tests, transformer) with tests a list,
        or (input_string, refstring, transformer)

        Each input string is parsed, passed through transformer.visit and
        compared with the parsed reference string using
        asttools.Comparator.

        Raises:
            TypeError: if called with neither 2 nor 3 arguments.
        """

        if len(args) not in (2, 3):
            # TypeError subclasses Exception, so existing handlers still
            # catch it; the message now matches the accepted arities
            raise TypeError("generic_AstCompTest should be " +
                            "called with 2 or 3 arguments")
        if len(args) == 2:
            tests, transformer = args
        else:
            tests = [(args[0], args[1])]
            transformer = args[2]
        for origstring, refstring in tests:
            orig = ast.parse(origstring)
            ref = ast.parse(refstring)
            orig = transformer.visit(orig)
            self.assertTrue(asttools.Comparator().visit(orig, ref))
Beispiel #20
0
 def check_neg(self, target, pattern):
     'Check (-1)*... pattern that could be in another form'
     operand = pattern.right
     # no wildcard to bind: fall back to the z3 equivalence check
     if not self.is_wildcard(operand):
         return self.check_eq_z3(target, pattern)
     wkey = operand.id
     if isinstance(target, ast.Num):
         # constant target: the wildcard stands for -target mod 2**nbits
         negated = ast.Num((-target.n) % 2**self.nbits)
         if wkey in self.wildcards:
             return asttools.Comparator().visit(self.wildcards[wkey], negated)
         self.wildcards[wkey] = negated
         return True
     if wkey not in self.wildcards:
         self.wildcards[wkey] = ast.BinOp(ast.Num(-1), ast.Mult(), target)
         return True
     return self.check_eq_z3(target, pattern)
Beispiel #21
0
 def check_not(self, target, pattern):
     'Check NOT pattern node that could be in another form'
     operand = pattern.operand
     # operand is not a wildcard: compare ~target against the sub-pattern
     if not self.is_wildcard(operand):
         return self.check_eq_z3(ast.UnaryOp(ast.Invert(), target), operand)
     wkey = operand.id
     if isinstance(target, ast.Num):
         # constant target: the wildcard stands for ~target mod 2**nbits
         inverted = ast.Num((~target.n) % 2**self.nbits)
         if wkey in self.wildcards:
             return asttools.Comparator().visit(self.wildcards[wkey],
                                                inverted)
         self.wildcards[wkey] = inverted
         return True
     if wkey not in self.wildcards:
         self.wildcards[wkey] = ast.UnaryOp(ast.Invert(), target)
         return True
     return self.check_eq_z3(target, pattern)
Beispiel #22
0
 def simplify(self, expr_ast, nbits):
     'Apply pattern matching and arithmetic simplification'
     expr_ast = arithm_simpl.run(expr_ast, nbits)
     expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
     if DEBUG:
         print "arithm simpl: "
         print unparse(expr_ast)
     if DEBUG:
         print "before matching: "
         print unparse(expr_ast)
     expr_ast = all_preprocessings(expr_ast, self.nbits)
     # only flattening ADD nodes because of traditionnal MBA patterns
     expr_ast = Flattening(ast.Add).visit(expr_ast)
     for pattern, repl in self.patterns:
         rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
         new_ast = rep.visit(deepcopy(expr_ast))
         if not asttools.Comparator().visit(new_ast, expr_ast):
             if DEBUG:
                 print "replaced! "
                 dispat = deepcopy(pattern)
                 dispat = Unflattening().visit(dispat)
                 print "pattern:  ", unparse(dispat)
                 disnew = deepcopy(new_ast)
                 disnew = Unflattening().visit(disnew)
                 print "after:    ", unparse(disnew)
                 print ""
             expr_ast = new_ast
             break
     # bitwise simplification: this is a ugly hack, should be
     # "generalized"
     expr_ast = Flattening(ast.BitXor).visit(expr_ast)
     expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
     expr_ast = Unflattening().visit(expr_ast)
     if DEBUG:
         print "after PM: "
         print unparse(expr_ast)
     return expr_ast
Beispiel #23
0
    def visit_Assign(self, node):
        'Simplify value of assignment and update context'

        def leveled(tree):
            'Return a leveled deep copy of tree, leaving tree untouched'
            return asttools.LevelOperators().visit(deepcopy(tree))

        # substitute variables already known from self.context
        node.value = pattern_matcher.EvalPattern(self.context).visit(
            node.value)
        before = leveled(node.value)
        node.value = self.simplify(node.value, self.nbits)
        after = leveled(node.value)
        # iterate simplify until two consecutive results compare equal
        while not asttools.Comparator().visit(before, after):
            previous = deepcopy(node.value)
            node.value = self.simplify(node.value, self.nbits)
            current = deepcopy(node.value)
            if len(unparse(current)) > len(unparse(previous)):
                # the round grew the expression: roll it back and stop
                node.value = deepcopy(previous)
                break
            after = asttools.LevelOperators().visit(current)
            before = asttools.LevelOperators().visit(previous)
        # record the simplified value under every assigned name
        for target in node.targets:
            self.context[target.id] = node.value
        return node
Beispiel #24
0
 def generic_test(self, input_ast, ref_ast, nbits):
     'Generic test for arithmetic simplification'
     # run the simplifier on input_ast with an nbits width, then check it
     # against the expected tree
     simplified = arithm_simpl.run(input_ast, nbits)
     matches = asttools.Comparator().visit(simplified, ref_ast)
     self.assertTrue(matches)
Beispiel #25
0
 def generic_leveling(self, refstring_list, result):
     'Test matching of leveled AST and ref AST'
     # lazily level each reference expression and compare it with result
     leveled_refs = (asttools.LevelOperators().visit(
         ast.parse(text, mode="eval").body) for text in refstring_list)
     for leveled in leveled_refs:
         self.assertTrue(asttools.Comparator().visit(leveled, result))
Beispiel #26
0
 def generic_test(self, expr, refstring, nbits=0):
     'Generic test for simplifier script'
     # compare parsed trees so the textual layout of the output is irrelevant
     output = ast.parse(simplifier.simplify(expr, nbits))
     expected = ast.parse(refstring)
     self.assertTrue(asttools.Comparator().visit(output, expected))