def test_root(self):
    'Replace a whole expression when every tree is parsed in eval mode'
    def parse_expr(source):
        'Parse source as a single expression (root node is ast.Expression)'
        return ast.parse(source, mode='eval')

    pattern = parse_expr("A + B")
    target = parse_expr("x + y")
    # The expected tree and the replacement tree are parsed separately so
    # the comparison never checks a tree against itself.
    expected = parse_expr("89")
    replacement = parse_expr("89")

    replacer = pattern_matcher.PatternReplacement(pattern, target,
                                                  replacement)
    result = replacer.visit(target)
    self.assertTrue(asttools.Comparator().visit(result, expected))
def generic_test_replacement(self, tests, pattern, replacement):
    """Check one pattern/replacement pair against a list of cases.

    tests -- iterable of (input_string, refstring) pairs; each input is
             rewritten and compared with the parsed reference string.
    pattern -- source text of the pattern to look for.
    replacement -- source text substituted for each match.
    """
    for source, expected_source in tests:
        # All four trees are rebuilt for every case, so no AST object is
        # shared between two runs of the replacer.
        target = ast.parse(source)
        expected = ast.parse(expected_source)
        pattern_tree = ast.parse(pattern)
        replacement_tree = ast.parse(replacement)

        replacer = pattern_matcher.PatternReplacement(
            pattern_tree, target, replacement_tree)
        result = replacer.visit(target)
        self.assertTrue(asttools.Comparator().visit(result, expected))
def simplify(self, expr_ast, nbits):
    """Apply pattern matching and arithmetic simplification.

    Steps, in order: all_target_preprocessings, a LevelOperators(ast.Add)
    pass, one PatternReplacement pass for every (pattern, replacement)
    pair in self.patterns, a LevelOperators(ast.BitXor) pass, ConstFolding,
    Unleveling, arithm_simpl.run and finally GetConstMod.

    expr_ast -- AST of the expression to simplify.
    nbits -- bit width handed to arithm_simpl.run.

    Returns the resulting AST. When the module-level DEBUG flag is true,
    intermediate results are printed along the way.

    NOTE(review): only arithm_simpl.run receives the nbits argument; every
    other step reads self.nbits. Confirm that callers keep both equal.
    """
    # Every print below is a single-argument print(...): that form parses
    # as a print statement on Python 2 and as a function call on Python 3,
    # and emits the same text on both. Two-item prints are spelled with
    # "%s %s" to reproduce the separating space of the print statement.
    if DEBUG:
        print("before: ")
        print(unparse(expr_ast))
        print("")

    expr_ast = all_target_preprocessings(expr_ast, self.nbits)
    expr_ast = asttools.LevelOperators(ast.Add).visit(expr_ast)

    for pattern, repl in self.patterns:
        rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
        # The replacer works on a copy, so expr_ast is still available
        # for the debug comparison below.
        new_ast = rep.visit(deepcopy(expr_ast))
        if DEBUG and not asttools.Comparator().visit(new_ast, expr_ast):
            print("replaced! ")
            expr_debug = asttools.Unleveling().visit(deepcopy(expr_ast))
            print(unparse(expr_debug))
            new_debug = asttools.Unleveling().visit(deepcopy(new_ast))
            print(unparse(new_debug))
            print("%s %s" % ("before: ", ast.dump(expr_ast)))
            print("%s %s" % ("pattern: ", ast.dump(pattern)))
            patt_debug = asttools.Unleveling().visit(deepcopy(pattern))
            print(unparse(patt_debug))
            print("")
            print("")
            print("%s %s" % ("after: ", ast.dump(new_ast)))
            print("")
        # The result of each pass feeds the next pattern, whether or not
        # it changed anything.
        expr_ast = new_ast

    # bitwise simplification: this is an ugly hack, should be
    # "generalized"
    expr_ast = asttools.LevelOperators(ast.BitXor).visit(expr_ast)
    expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
    expr_ast = asttools.Unleveling().visit(expr_ast)
    if DEBUG:
        print("after PM: ")
        print(unparse(expr_ast))
        print("")

    expr_ast = arithm_simpl.run(expr_ast, nbits)
    expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
    if DEBUG:
        print("arithm simpl: ")
        print(unparse(expr_ast))
        print("")
        print("-" * 80)
    return expr_ast
def test_leveled(self):
    'Check pattern replacement on leveled ast'
    def leveled(source):
        'Parse source and run a fresh LevelOperators pass over it'
        return asttools.LevelOperators().visit(ast.parse(source))

    replacement_source = "A"

    # Positive case: the operands of the input appear in a different
    # order from the ones in the pattern.
    expected = ast.parse("x", mode='eval').body
    result = pattern_matcher.replace("3*z + x + 2*y", "A + 2*B + 3*C",
                                     replacement_source)
    self.assertTrue(asttools.Comparator().visit(result, expected))

    # only ADD nodes are leveled right now, this is for code
    # coverage: the xor input is expected to come back unchanged.
    xor_source = "3*z ^ x ^ 2*y"
    target = leveled(xor_source)
    pattern = leveled("A + 3*z")
    replacement = ast.parse(replacement_source)
    expected = leveled(xor_source)

    replacer = pattern_matcher.PatternReplacement(pattern, target,
                                                  replacement)
    result = replacer.visit(target)
    self.assertTrue(asttools.Comparator().visit(result, expected))
def simplify(self, expr_ast, nbits):
    """Apply arithmetic simplification, then pattern matching.

    Steps, in order: arithm_simpl.run, GetConstMod, all_preprocessings,
    a Flattening(ast.Add) pass, the pattern loop, a Flattening(ast.BitXor)
    pass, ConstFolding and Unflattening.

    The pattern loop walks self.patterns and stops at the first
    (pattern, replacement) pair whose application changes the expression;
    later pairs are not tried in this call.

    expr_ast -- AST of the expression to simplify.
    nbits -- bit width handed to arithm_simpl.run.

    Returns the resulting AST. When the module-level DEBUG flag is true,
    intermediate results are printed along the way.

    NOTE(review): only arithm_simpl.run receives the nbits argument; every
    other step reads self.nbits. Confirm that callers keep both equal.
    """
    # Every print below is a single-argument print(...): that form parses
    # as a print statement on Python 2 and as a function call on Python 3,
    # and emits the same text on both. Two-item prints are spelled with
    # "%s %s" to reproduce the separating space of the print statement.
    expr_ast = arithm_simpl.run(expr_ast, nbits)
    expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
    if DEBUG:
        print("arithm simpl: ")
        print(unparse(expr_ast))
        print("before matching: ")
        print(unparse(expr_ast))

    expr_ast = all_preprocessings(expr_ast, self.nbits)
    # only flattening ADD nodes because of traditional MBA patterns
    expr_ast = Flattening(ast.Add).visit(expr_ast)

    for pattern, repl in self.patterns:
        rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
        # The replacer works on a copy so the untouched expr_ast can be
        # compared with the result.
        new_ast = rep.visit(deepcopy(expr_ast))
        if asttools.Comparator().visit(new_ast, expr_ast):
            # Nothing changed: try the next pattern.
            continue
        if DEBUG:
            print("replaced! ")
            dispat = Unflattening().visit(deepcopy(pattern))
            print("%s %s" % ("pattern: ", unparse(dispat)))
            disnew = Unflattening().visit(deepcopy(new_ast))
            print("%s %s" % ("after: ", unparse(disnew)))
            print("")
        expr_ast = new_ast
        break

    # bitwise simplification: this is an ugly hack, should be
    # "generalized"
    expr_ast = Flattening(ast.BitXor).visit(expr_ast)
    expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
    expr_ast = Unflattening().visit(expr_ast)
    if DEBUG:
        print("after PM: ")
        print(unparse(expr_ast))
    return expr_ast