def get_metrics(expr_ast):
    """Return the node count and the MBA alternation of the expression DAG.

    The caller's AST is not modified: a deep copy is unflattened, run
    through ``cse.apply_cse`` and then translated by ``DAGTranslator``.

    Returns a ``(len(graph), alternation)`` tuple.
    """
    working_ast = Unflattening().visit(deepcopy(expr_ast))
    # Only the element at index 1 of apply_cse's result is used here.
    working_ast = cse.apply_cse(working_ast)[1]

    translator = DAGTranslator(working_ast)
    translator.visit(working_ast)

    dag = translator.graph
    # Variables are grouped in a subgraph with rank="same"; the value
    # returned by subgraph() is not needed.
    dag.subgraph(list(translator.variables), rank="same")
    return len(dag), translator.alternation
def get_model(self, target, pattern):
    """Find a wildcard value that makes *pattern* equal to constant *target*.

    Used when the target is a constant and the wildcard of the pattern may
    not have a value yet.  Only patterns without function calls and with
    exactly one wildcard are handled; anything else yields False.

    If the wildcard already has a value in ``self.wildcards``, the pattern
    is instantiated, constant-folded and compared to ``target.n``.
    Otherwise z3 is asked for a model; when one exists, its values are
    stored in ``self.wildcards`` as ``ast.Num`` nodes.

    Returns True when the pattern can be (or is) equal to the target,
    False otherwise.
    """
    if target.n == 0:
        # zero is too permissive
        return False

    getwild = asttools.GetIdentifiers()
    getwild.visit(pattern)
    if getwild.functions:
        # not getting model for expr with functions
        return False

    wilds = getwild.variables
    # let's reduce the model to one wildcard for now
    # otherwise it adds a lot of checks...
    # A pattern without any wildcard is rejected as well: there is nothing
    # to solve for, and pop() below would raise on an empty collection.
    if len(wilds) != 1:
        return False
    wil = wilds.pop()

    if wil in self.wildcards:
        if not isinstance(self.wildcards[wil], ast.Num):
            return False
        folded = deepcopy(pattern)
        folded = Unflattening().visit(folded)
        EvalPattern(self.wildcards).visit(folded)
        folded = asttools.ConstFolding(folded, self.nbits).visit(folded)
        return folded.n == target.n

    eval_pattern = deepcopy(pattern)
    eval_pattern = Unflattening().visit(eval_pattern)
    ast.fix_missing_locations(eval_pattern)
    code = compile(ast.Expression(eval_pattern), '<string>', mode='eval')

    # The symbolic wildcard lives in its own namespace, so its name can
    # never shadow a local of this method (target, pattern, code, ...).
    namespace = {wil: z3.BitVec(wil, self.nbits)}
    sol = z3.Solver()
    sol.add(target.n == eval(code, namespace))
    if sol.check() == z3.sat:
        model = sol.model()
        for inst in model.decls():
            self.wildcards[str(inst)] = ast.Num(int(model[inst].as_long()))
        return True
    return False
def test_unflattening(self):
    """Flattening then unflattening must produce the reference AST."""
    cases = (
        ("x + (3 + y)", "3 + (y + x)"),
        ("x*(2*z)", "2*(z*x)"),
        ("x + (y + (z*(5*var)))", "y + (5*(var*z) + x)"),
    )
    for source, expected in cases:
        expected_ast = ast.parse(expected)
        tree = ast.parse(source)
        # The return values of both visitors are ignored; the checks
        # below are made on the same tree object.
        Flattening().visit(tree)
        Unflattening().visit(tree)
        same = Comparator().visit(tree, expected_ast)
        self.assertTrue(same)
        # No flattened (BoolOp) node may remain in the unparsed output.
        rendered = astunparse.unparse(tree)
        self.assertFalse('BoolOp' in rendered)
def check_eq_z3(self, target, pattern):
    """Check with z3 that *target* equals the instantiated *pattern*.

    The wildcards of the pattern are replaced using ``self.wildcards``;
    both expressions are then evaluated over z3 bit-vectors of
    ``self.nbits`` bits, one per name in ``self.variables``.

    Returns True only when z3 reports ``target != pattern`` as
    unsatisfiable.  Returns False without a z3 query when either
    expression contains a function call, when an upper-case (i.e. not yet
    replaced) name remains in the pattern, or when the target evaluates to
    the integer 0.
    """
    getid = asttools.GetIdentifiers()
    getid.visit(target)
    if getid.functions:
        # not checking exprs with functions for now, because Z3
        # does not seem to support function declaration with
        # arbitrary number of arguments
        return False

    # Symbolic variables are kept in a dedicated namespace, so a variable
    # of the expression named like a local of this method (var, target,
    # sol, ...) cannot overwrite it or be overwritten by it.
    namespace = {var: z3.BitVec(var, self.nbits) for var in self.variables}

    target_ast = deepcopy(target)
    target_ast = Unflattening().visit(target_ast)
    ast.fix_missing_locations(target_ast)
    code1 = compile(ast.Expression(target_ast), '<string>', mode='eval')

    eval_pattern = deepcopy(pattern)
    EvalPattern(self.wildcards).visit(eval_pattern)
    eval_pattern = Unflattening().visit(eval_pattern)
    ast.fix_missing_locations(eval_pattern)

    getid.reset()
    getid.visit(eval_pattern)
    if getid.functions:
        # same reason as before, not using Z3 if there are
        # functions
        return False

    gvar = asttools.GetIdentifiers()
    gvar.visit(eval_pattern)
    if any(var.isupper() for var in gvar.variables):
        # do not check if all patterns have not been replaced
        return False
    code2 = compile(ast.Expression(eval_pattern), '<string>', mode='eval')

    # Evaluate the target once and reuse the value.
    target_value = eval(code1, namespace)
    if isinstance(target_value, int) and target_value == 0:
        # cases where target == 0 are too permissive
        return False

    sol = z3.Solver()
    sol.add(target_value != eval(code2, namespace))
    return sol.check() == z3.unsat
def simplify(self, expr_ast, nbits): 'Apply pattern matching and arithmetic simplification' expr_ast = arithm_simpl.run(expr_ast, nbits) expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast) if DEBUG: print "arithm simpl: " print unparse(expr_ast) if DEBUG: print "before matching: " print unparse(expr_ast) expr_ast = all_preprocessings(expr_ast, self.nbits) # only flattening ADD nodes because of traditionnal MBA patterns expr_ast = Flattening(ast.Add).visit(expr_ast) for pattern, repl in self.patterns: rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl) new_ast = rep.visit(deepcopy(expr_ast)) if not asttools.Comparator().visit(new_ast, expr_ast): if DEBUG: print "replaced! " dispat = deepcopy(pattern) dispat = Unflattening().visit(dispat) print "pattern: ", unparse(dispat) disnew = deepcopy(new_ast) disnew = Unflattening().visit(disnew) print "after: ", unparse(disnew) print "" expr_ast = new_ast break # bitwise simplification: this is a ugly hack, should be # "generalized" expr_ast = Flattening(ast.BitXor).visit(expr_ast) expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast) expr_ast = Unflattening().visit(expr_ast) if DEBUG: print "after PM: " print unparse(expr_ast) return expr_ast
def visit_BoolOp(self, node):
    """Replace the pattern in a BoolOp, on the whole node or on a subset.

    With a BoolOp pattern of the same length as *node*, the work is
    delegated to ``basic_visit``.  With a shorter BoolOp pattern, or with
    a BinOp pattern whose operator type equals the node's, every
    combination of operands of the pattern's size is tried; on the first
    match the combination is replaced by the instantiated replacement and
    the unflattened result is returned.  In every other case the node is
    handled by ``generic_visit``.
    """
    patt = self.patt_ast
    if isinstance(patt, ast.BoolOp):
        arity = len(patt.values)
        if len(node.values) == arity:
            return self.basic_visit(node)
        if len(node.values) < arity:
            return self.generic_visit(node)

        def build(operands):
            # Candidate has the same shape as the pattern: a BoolOp.
            return ast.BoolOp(node.op, list(operands))
    elif isinstance(patt, ast.BinOp):
        if type(node.op) != type(patt.op):
            return self.generic_visit(node)
        arity = 2

        def build(operands):
            # Candidate has the same shape as the pattern: a BinOp.
            return ast.BinOp(operands[0], node.op, operands[1])
    else:
        return self.generic_visit(node)

    # associativity n to m
    for chosen in itertools.combinations(node.values, arity):
        leftover = [elem for elem in node.values if elem not in chosen]
        candidate = build(chosen)
        matcher = PatternMatcher(candidate, self.nbits)
        if matcher.visit(candidate, self.patt_ast):
            rewritten = EvalPattern(matcher.wildcards).visit(self.rep_ast)
            rewritten = ast.BoolOp(node.op, [rewritten] + leftover)
            return Unflattening().visit(rewritten)
    return self.generic_visit(node)
def visit_BoolOp(self, node):
    """Fold the constant operands of a flattened BoolOp into one constant.

    A custom BoolOp can be used in a flattened AST.  Only the Add, Mult,
    BitXor, BitAnd and BitOr operators are handled, and only when the node
    holds at least two ``ast.Num`` operands; otherwise the node goes
    through ``generic_visit``.

    Returns a new BoolOp with the same operator whose values are the
    non-constant operands followed by a single ``ast.Num`` holding the
    evaluated constants.
    """
    foldable_ops = (ast.Add, ast.Mult, ast.BitXor, ast.BitAnd, ast.BitOr)
    if type(node.op) not in foldable_ops:
        return self.generic_visit(node)

    # get constant parts of node:
    list_cste = [child for child in node.values
                 if isinstance(child, ast.Num)]
    if len(list_cste) < 2:
        return self.generic_visit(node)
    rest_values = [child for child in node.values
                   if not isinstance(child, ast.Num)]

    # Build a plain expression out of the constants and evaluate it.
    fake_node = Unflattening().visit(ast.BoolOp(node.op, list_cste))
    fake_node = ast.Expression(fake_node)
    ast.fix_missing_locations(fake_node)
    code = compile(fake_node, '<constant folding>', 'eval')
    # Evaluated exactly once, in a copy of the module globals so that the
    # evaluation cannot modify them.
    value = eval(code, globals().copy())

    rest_values.append(ast.Num(value))
    return ast.BoolOp(node.op, rest_values)
# NOTE(review): the statements on the next line are the tail of a definition
# that starts before this chunk; they are kept verbatim.
return new_node return self.generic_visit(node)


def replace(target_str, pattern_str, replacement_str):
    """Apply pre-processing and replace.

    The target and pattern strings are parsed as expressions,
    pre-processed and flattened on ADD nodes.  A ``PatternReplacement``
    built from them and from the parsed replacement string is then run on
    the target, and the value returned by its ``visit`` is returned.
    """
    target_ast = ast.parse(target_str, mode="eval").body
    target_ast = pre_processing.all_preprocessings(target_ast)
    target_ast = Flattening(ast.Add).visit(target_ast)
    patt_ast = ast.parse(pattern_str, mode="eval").body
    patt_ast = pre_processing.all_preprocessings(patt_ast)
    patt_ast = Flattening(ast.Add).visit(patt_ast)
    # NOTE(review): unlike the target and the pattern, the replacement is
    # parsed in the default mode and without taking .body -- confirm that
    # PatternReplacement expects this form.
    rep_ast = ast.parse(replacement_str)
    rep = PatternReplacement(patt_ast, target_ast, rep_ast)
    return rep.visit(target_ast)


# Used for debug purposes:
if __name__ == '__main__':
    # pylint: disable=invalid-name
    patt_string = "A + B - (A | B)"
    test = "f(g(x + x) + 3 + 4)"
    repl = "A & B"
    # Print the result of matching, then of replacing, on the same input.
    print match(test, patt_string)
    print "-"*80
    out = replace(test, patt_string, repl)
    print ast.dump(out)
    # Unflatten before unparsing the replaced AST.
    out = Unflattening().visit(out)
    print astunparse.unparse(out)