示例#1
0
def get_metrics(expr_ast):
    """Return ``(node_count, alternation)`` for the DAG of *expr_ast*.

    The caller's tree is deep-copied first and never modified.  The copy is
    unflattened, run through common-subexpression elimination (the second
    item of ``cse.apply_cse``'s result is the rewritten tree) and handed to
    ``DAGTranslator``.  The node count is ``len`` of the translator's graph;
    the alternation value is read from ``translator.alternation``.
    """
    copied = deepcopy(expr_ast)
    tree = Unflattening().visit(copied)
    tree = cse.apply_cse(tree)[1]

    translator = DAGTranslator(tree)
    translator.visit(tree)
    dag = translator.graph
    # Put all variable nodes in one "same rank" subgraph of the graph;
    # presumably a layout hint for rendering -- NOTE(review): confirm this
    # never adds nodes, since the count below is taken after it.
    dag.subgraph(list(translator.variables), rank="same")
    return len(dag), translator.alternation
示例#2
0
    def get_model(self, target, pattern):
        """Try to bind the pattern's wildcard so that it evaluates to *target*.

        Used when *target* is a constant (its value is ``target.n``) and the
        wildcard may not have a value yet.  Only patterns with exactly one
        wildcard and no function calls are handled.

        * If the wildcard is already bound, it must be bound to an
          ``ast.Num``; the pattern is then constant-folded with that binding
          and the result is compared with ``target.n``.
        * Otherwise the wildcard becomes a z3 bit-vector of ``self.nbits``
          bits, and z3 searches for a value satisfying
          ``target.n == pattern``.  On success, every declaration of the
          model is stored in ``self.wildcards`` as an ``ast.Num``.

        Returns True when such a value exists (or the existing binding
        matches), False otherwise.  A target of 0 is always rejected.
        """
        # pylint: disable=eval-used
        if target.n == 0:
            # zero is too permissive
            return False
        getwild = asttools.GetIdentifiers()
        getwild.visit(pattern)
        if getwild.functions:
            # not getting model for expr with functions
            return False
        wilds = getwild.variables
        # Restrict the model to exactly one wildcard: more would need a lot
        # of extra checks, and with none there is nothing to solve for (and
        # ``pop`` below would fail).
        if len(wilds) != 1:
            return False

        wil = wilds.pop()
        if wil in self.wildcards:
            if not isinstance(self.wildcards[wil], ast.Num):
                return False
            folded = deepcopy(pattern)
            folded = Unflattening().visit(folded)
            # Keep the returned root: for a bare-wildcard pattern the root
            # node itself is what gets substituted.
            folded = EvalPattern(self.wildcards).visit(folded)
            folded = asttools.ConstFolding(folded, self.nbits).visit(folded)
            return folded.n == target.n

        eval_pattern = deepcopy(pattern)
        eval_pattern = Unflattening().visit(eval_pattern)
        ast.fix_missing_locations(eval_pattern)
        code = compile(ast.Expression(eval_pattern), '<string>', mode='eval')
        # Evaluate in an explicit namespace holding only the symbolic
        # wildcard.  The previous version exec'd an assignment into this
        # frame, so a wildcard named like a local (``target``, ``sol``,
        # ``z3``...) would overwrite it, and it relied on exec being able
        # to create function locals, which Python 3 does not allow.
        namespace = {wil: z3.BitVec(wil, self.nbits)}
        sol = z3.Solver()
        sol.add(target.n == eval(code, namespace))
        if sol.check() == z3.sat:
            model = sol.model()
            for inst in model.decls():
                self.wildcards[str(inst)] = ast.Num(int(model[inst].as_long()))
            return True
        return False
示例#3
0
    def test_unflattening(self):
        """Check that Flattening followed by Unflattening gives the reference.

        Each test expression is flattened and then unflattened in place.  The
        result must compare equal (via ``Comparator``) to the reference
        expression.  It must also contain no ``BoolOp`` node, which is the
        custom node used for n-ary operators in flattened ASTs.
        """
        tests = [("x + (3 + y)", "3 + (y + x)"), ("x*(2*z)", "2*(z*x)"),
                 ("x + (y + (z*(5*var)))", "y + (5*(var*z) + x)")]

        for test, ref in tests:
            ref_ast = ast.parse(ref)
            ast_test = ast.parse(test)
            Flattening().visit(ast_test)
            Unflattening().visit(ast_test)
            self.assertTrue(Comparator().visit(ast_test, ref_ast))
            # Search the AST dump, not the unparsed source: node class names
            # only appear in ast.dump output, so the unparsed text can never
            # contain 'BoolOp' and the old check could not fail.
            self.assertNotIn('BoolOp', ast.dump(ast_test))
            # The unflattened tree must still be unparseable back to source.
            astunparse.unparse(ast_test)
示例#4
0
 def check_eq_z3(self, target, pattern):
     """Check with z3 whether *target* and the instantiated *pattern* agree.

     Every name in ``self.variables`` becomes a z3 bit-vector of
     ``self.nbits`` bits.  The pattern has its wildcards substituted from
     ``self.wildcards``, and z3 is asked whether the two expressions can
     differ.  Returns True only when that query is unsatisfiable, i.e. the
     expressions are proved equivalent.

     Returns False without a proof attempt when either side contains a
     function call, when an upper-case name (an unreplaced wildcard)
     remains in the pattern, or when the target evaluates to the plain
     integer 0.
     """
     # pylint: disable=eval-used
     getid = asttools.GetIdentifiers()
     getid.visit(target)
     if getid.functions:
         # not checking exprs with functions for now, because Z3
         # does not seem to support function declaration with
         # arbitrary number of arguments
         return False
     # Explicit evaluation namespace instead of exec-ing assignments into
     # this frame: a variable called e.g. ``self``, ``target``, ``sol`` or
     # ``var`` can no longer overwrite this method's own locals.
     namespace = {var: z3.BitVec(var, self.nbits)
                  for var in self.variables}
     target_ast = deepcopy(target)
     target_ast = Unflattening().visit(target_ast)
     ast.fix_missing_locations(target_ast)
     code1 = compile(ast.Expression(target_ast), '<string>', mode='eval')
     eval_pattern = deepcopy(pattern)
     # Keep the returned root: a bare-wildcard pattern is replaced as a whole.
     eval_pattern = EvalPattern(self.wildcards).visit(eval_pattern)
     eval_pattern = Unflattening().visit(eval_pattern)
     ast.fix_missing_locations(eval_pattern)
     # Identifiers left in the pattern after wildcard substitution.
     patt_ids = asttools.GetIdentifiers()
     patt_ids.visit(eval_pattern)
     if patt_ids.functions:
         # same reason as before, not using Z3 if there are
         # functions
         return False
     if any(var.isupper() for var in patt_ids.variables):
         # do not check if all patterns have not been replaced
         return False
     code2 = compile(ast.Expression(eval_pattern), '<string>', mode='eval')
     target_expr = eval(code1, namespace)
     if isinstance(target_expr, int) and target_expr == 0:
         # cases where target == 0 are too permissive
         return False
     sol = z3.Solver()
     sol.add(target_expr != eval(code2, namespace))
     return sol.check() == z3.unsat
示例#5
0
 def simplify(self, expr_ast, nbits):
     'Apply pattern matching and arithmetic simplification'
     expr_ast = arithm_simpl.run(expr_ast, nbits)
     expr_ast = asttools.GetConstMod(self.nbits).visit(expr_ast)
     if DEBUG:
         print "arithm simpl: "
         print unparse(expr_ast)
     if DEBUG:
         print "before matching: "
         print unparse(expr_ast)
     expr_ast = all_preprocessings(expr_ast, self.nbits)
     # only flattening ADD nodes because of traditionnal MBA patterns
     expr_ast = Flattening(ast.Add).visit(expr_ast)
     for pattern, repl in self.patterns:
         rep = pattern_matcher.PatternReplacement(pattern, expr_ast, repl)
         new_ast = rep.visit(deepcopy(expr_ast))
         if not asttools.Comparator().visit(new_ast, expr_ast):
             if DEBUG:
                 print "replaced! "
                 dispat = deepcopy(pattern)
                 dispat = Unflattening().visit(dispat)
                 print "pattern:  ", unparse(dispat)
                 disnew = deepcopy(new_ast)
                 disnew = Unflattening().visit(disnew)
                 print "after:    ", unparse(disnew)
                 print ""
             expr_ast = new_ast
             break
     # bitwise simplification: this is a ugly hack, should be
     # "generalized"
     expr_ast = Flattening(ast.BitXor).visit(expr_ast)
     expr_ast = asttools.ConstFolding(expr_ast, self.nbits).visit(expr_ast)
     expr_ast = Unflattening().visit(expr_ast)
     if DEBUG:
         print "after PM: "
         print unparse(expr_ast)
     return expr_ast
示例#6
0
    def visit_BoolOp(self, node):
        """Check if a flattened BoolOp exactly matches or contains the pattern.

        * Pattern is a BoolOp with as many operands as *node*: delegated to
          ``basic_visit``.
        * Pattern is a BoolOp with fewer operands: every combination of that
          many operands of *node* is tried (n-to-m associativity).
        * Pattern is a BinOp with the same operator as *node*: every pair of
          operands is tried.

        On the first match, the matched operands are replaced by the
        instantiated replacement (``self.rep_ast`` with the matched
        wildcards substituted).  The other operands are kept, and the rebuilt
        node is unflattened and returned.  Otherwise the children are visited
        normally.
        """
        from copy import deepcopy

        if isinstance(self.patt_ast, ast.BoolOp):
            if len(node.values) == len(self.patt_ast.values):
                return self.basic_visit(node)
            elif len(node.values) > len(self.patt_ast.values):
                # associativity n to m
                for combi in itertools.combinations(node.values,
                                                    len(self.patt_ast.values)):
                    rest = [elem for elem in node.values if elem not in combi]
                    testnode = ast.BoolOp(node.op, list(combi))
                    pat = PatternMatcher(testnode, self.nbits)
                    matched = pat.visit(testnode, self.patt_ast)
                    if matched:
                        # EvalPattern rewrites the tree it visits in place:
                        # substitute into a copy so the shared replacement
                        # template keeps its wildcards for later matches.
                        new = EvalPattern(pat.wildcards).visit(
                            deepcopy(self.rep_ast))
                        new = ast.BoolOp(node.op, [new] + rest)
                        new = Unflattening().visit(new)
                        return new
            return self.generic_visit(node)

        if isinstance(self.patt_ast, ast.BinOp):
            if type(node.op) != type(self.patt_ast.op):
                return self.generic_visit(node)
            op = node.op
            for combi in itertools.combinations(node.values, 2):
                rest = [elem for elem in node.values if elem not in combi]
                testnode = ast.BinOp(combi[0], op, combi[1])
                pat = PatternMatcher(testnode, self.nbits)
                matched = pat.visit(testnode, self.patt_ast)
                if matched:
                    # copy for the same reason as above
                    new_node = EvalPattern(pat.wildcards).visit(
                        deepcopy(self.rep_ast))
                    new_node = ast.BoolOp(op, [new_node] + rest)
                    new_node = Unflattening().visit(new_node)
                    return new_node
        return self.generic_visit(node)
示例#7
0
    def visit_BoolOp(self, node):
        """Fold the constant operands of a flattened n-ary node.

        A custom BoolOp can be used in flattened AST to hold an n-ary
        ``+``, ``*``, ``^``, ``&`` or ``|``.  Children are folded first.
        Then, if at least two operands are ``ast.Num`` constants, they are
        evaluated into a single ``ast.Num``, which is appended after the
        non-constant operands.  Nodes with another operator or with fewer
        than two constants are returned with only their children folded.
        """
        # Fold bottom-up: previously the children of a node holding two or
        # more constants were never visited, and constants produced by
        # folding a child could not be merged here.
        node = self.generic_visit(node)
        if type(node.op) not in (ast.Add, ast.Mult, ast.BitXor, ast.BitAnd,
                                 ast.BitOr):
            return node
        # get constant parts of node:
        list_cste = [
            child for child in node.values if isinstance(child, ast.Num)
        ]
        if len(list_cste) < 2:
            return node
        rest_values = [n for n in node.values if n not in list_cste]
        fake_node = Unflattening().visit(ast.BoolOp(node.op, list_cste))
        fake_node = ast.Expression(fake_node)
        ast.fix_missing_locations(fake_node)
        code = compile(fake_node, '<constant folding>', 'eval')
        # The expression is built only from numeric literals, so it needs no
        # names: evaluate it once in an empty namespace.  The old code
        # exec'd the same expression before eval'ing it, and copied the
        # module globals on every call.
        value = eval(code, {})  # pylint: disable=eval-used

        new_node = ast.Num(value)
        rest_values.append(new_node)
        return ast.BoolOp(node.op, rest_values)
示例#8
0
                    return new_node
        return self.generic_visit(node)


def replace(target_str, pattern_str, replacement_str):
    """Replace occurrences of *pattern_str* in *target_str*; return the AST.

    Target and pattern are parsed as single expressions and go through the
    same pipeline: ``pre_processing.all_preprocessings``, then flattening of
    ``ast.Add`` nodes.  The replacement is parsed with a plain
    ``ast.parse`` call.  The result is the target tree after a
    ``PatternReplacement`` visit.

    NOTE(review): unlike the other two inputs, the replacement is parsed as
    a whole module (no ``mode="eval"`` / ``.body``) -- confirm that
    ``PatternReplacement`` expects an ``ast.Module`` here.
    """
    def prepare(source):
        # expression body -> pre-processing -> flattened additions
        tree = ast.parse(source, mode="eval").body
        tree = pre_processing.all_preprocessings(tree)
        return Flattening(ast.Add).visit(tree)

    target_ast = prepare(target_str)
    patt_ast = prepare(pattern_str)
    rep_ast = ast.parse(replacement_str)
    return PatternReplacement(patt_ast, target_ast, rep_ast).visit(target_ast)


# Used for debug purposes:
if __name__ == '__main__':
    # pylint: disable=invalid-name
    patt_string = "A + B - (A | B)"
    test = "f(g(x + x) + 3 + 4)"
    repl = "A & B"

    print match(test, patt_string)
    print "-"*80
    out = replace(test, patt_string, repl)
    print ast.dump(out)
    out = Unflattening().visit(out)
    print astunparse.unparse(out)