예제 #1
0
    def check_twomult(self, target, pattern):
        '''Check 2*... pattern that could be in another form.

        ``pattern`` is expected to have ``left``/``right`` operands, one of
        them being the constant 2; anything else returns False.

        When ``target`` is a constant and the other operand is a plain name
        (a wildcard), the check is done directly: either against the
        wildcard's known constant value (reduced modulo ``2**self.nbits``),
        or by binding the wildcard to ``target.n // 2`` when it has no
        constant value yet.  Otherwise every name in the operand must
        already be a key of ``self.wildcards`` and the equivalence is
        delegated to ``self.check_eq_z3``.
        '''
        if isinstance(pattern.left, ast.Num) and pattern.left.n == 2:
            operand = pattern.right
        elif isinstance(pattern.right, ast.Num) and pattern.right.n == 2:
            operand = pattern.left
        else:
            return False

        # deal with case where wildcard operand and target are const values
        if isinstance(target, ast.Num) and isinstance(operand, ast.Name):
            conds = (operand.id in self.wildcards
                     and isinstance(self.wildcards[operand.id], ast.Num))
            if conds:
                # wildcard already bound to a constant: compare 2*value,
                # wrapped to nbits, with the target
                eva = (self.wildcards[operand.id].n) * 2 % 2**(self.nbits)
                if eva == target.n:
                    return True
                # on mismatch, fall through to the generic z3 check below
            else:
                # NOTE(review): this branch also overwrites a wildcard that
                # is already bound to a non-constant expression — confirm
                # that is intended.
                # 2*x is always even, so only an even target has a solution;
                # floor division keeps the bound value an integer (``/`` is
                # true division under Python 3 and would yield a float)
                if target.n % 2 == 0:
                    self.wildcards[operand.id] = ast.Num(target.n // 2)
                    return True
                return False

        # get all wildcards in operand and check if they have value
        getwild = asttools.GetIdentifiers()
        getwild.visit(operand)
        wilds = getwild.variables
        for wil in wilds:
            if wil not in self.wildcards:
                return False
        return self.check_eq_z3(target, pattern)
예제 #2
0
    def __init__(self, root, nbits=0):
        '''Init different components of pattern matcher.

        ``root`` is unwrapped before use: for an ``ast.Module`` the value of
        its first statement is kept, for an ``ast.Expression`` its body is
        kept, any other node is kept as is.  A falsy ``nbits`` means the bit
        width is obtained from ``asttools.get_default_nbits`` on that node.
        '''
        super(PatternMatcher, self).__init__()

        # candidate values of the wildcards used in the pattern
        self.wildcards = {}
        # wildcard/value associations already known not to work
        self.no_solution = []

        # strip the container node to reach the expression itself
        if isinstance(root, ast.Module):
            node = root.body[0].value
        elif isinstance(root, ast.Expression):
            node = root.body
        else:
            node = root
        self.root = node

        self.nbits = nbits or asttools.get_default_nbits(self.root)

        # identifiers of the root expression, needed for z3 evaluation
        collector = asttools.GetIdentifiers()
        collector.visit(self.root)
        self.variables = collector.variables
        self.functions = collector.functions
예제 #3
0
    def test_xor36(self):
        'Test that CSE of the xor36 function is equivalent to original'
        # pylint: disable=exec-used,eval-used
        pwd = os.path.dirname(os.path.realpath(__file__))
        # context manager so the fixture file is closed even on failure
        with open(os.path.join(pwd, 'xor36_flat'), 'r') as input_file:
            input_string = input_file.read()
        input_ast = ast.parse(input_string)
        coderef = compile(ast.Expression(input_ast.body[0].value), '<string>',
                          'eval')
        jack = asttools.GetIdentifiers()
        jack.visit(input_ast)

        cse_string = cse.apply_cse(input_string)
        # get all assignment in one ast
        assigns = cse_string[:cse_string.rfind('\n')]
        cse_assign_ast = ast.parse(assigns, mode='exec')
        assign_code = compile(cse_assign_ast, '<string>', mode='exec')
        # get final expression in one ast
        result_string = cse_string.splitlines()[-1]
        result_ast = ast.Expression(ast.parse(result_string).body[0].value)
        result_code = compile(result_ast, '<string>', mode='eval')

        # Symbolic names live in an explicit namespace: exec() cannot
        # reliably create function locals under Python 3.
        namespace = {}
        for var in jack.variables:
            namespace[var] = z3.BitVec(var, 8)
        # CSE assignments land in the same namespace the final expression
        # is evaluated in
        exec(assign_code, globals(), namespace)
        sol = z3.Solver()
        sol.add(eval(coderef, globals(), namespace) !=
                eval(result_code, globals(), namespace))
        # unsat: no input makes the two expressions differ
        self.assertEqual(sol.check(), z3.unsat)
예제 #4
0
 def generic_test(self, tests):
     '''Generic test for GetIdentifiers class.

     ``tests`` is an iterable of ``(source, variables, functions)`` triples:
     ``source`` is parsed, visited with a reset ``GetIdentifiers`` and its
     ``variables``/``functions`` are compared with the expected values.
     '''
     geti = asttools.GetIdentifiers()
     for instring, varref, funref in tests:
         # reuse one visitor; reset clears state left by the previous case
         geti.reset()
         inast = ast.parse(instring)
         geti.visit(inast)
         # assertEqual: the assertEquals alias was removed in Python 3.12
         self.assertEqual(geti.variables, varref)
         self.assertEqual(geti.functions, funref)
예제 #5
0
 def general_check(self, target, pattern):
     '''General check, very time-consuming, not used at the moment.

     Only proceeds when every variable name found in ``pattern`` already
     has an entry in ``self.wildcards``.  In that case a copy of the pattern
     is instantiated with those values (``EvalPattern``) and compared with
     ``target`` through ``self.check_eq_z3``; otherwise returns False.
     '''
     collector = asttools.GetIdentifiers()
     collector.visit(pattern)
     unbound = [name for name in collector.variables
                if name not in self.wildcards]
     if unbound:
         # at least one wildcard has no value yet: nothing to compare
         return False
     instantiated = EvalPattern(self.wildcards).visit(deepcopy(pattern))
     return self.check_eq_z3(target, instantiated)
예제 #6
0
def run(expr_ast, nbits):
    'Apply sympy arithmetic simplifications to expression ast'

    # variables for sympy symbols
    getid = asttools.GetIdentifiers()
    getid.visit(expr_ast)
    variables = getid.variables
    functions = getid.functions

    original_type = type(expr_ast)
    # copying to avoid wierd pointer behaviour
    expr_ast = deepcopy(expr_ast)
    # converting expr_ast into an ast.Expression
    if not isinstance(expr_ast, ast.Expression):
        if isinstance(expr_ast, ast.Module):
            expr_ast = ast.Expression(expr_ast.body[0].value)
        elif isinstance(expr_ast, ast.Expr):
            expr_ast = ast.Expression(expr_ast.value)
        else:
            expr_ast = ast.Expression(expr_ast)

    for var in variables:
        exec("%s = sympy.Symbol('%s')" % (var, var))
    for fun in {"mxor", "mor", "mand", "mnot", "mrshift", "mlshift"}:
        exec("%s = sympy.Function('%s')" % (fun, fun))
    for fun in functions:
        exec("%s = sympy.Function('%s')" % (fun, fun))
    expr_ast = asttools.ReplaceBitwiseOp().visit(expr_ast)
    ast.fix_missing_locations(expr_ast)
    code = compile(expr_ast, '<test>', mode='eval')
    eval_expr = eval(code)
    try:
        expr_ast = ast.parse(str(eval_expr))
    except SyntaxError as ex:
        print ex
        exit(1)

    expr_ast = asttools.ReplaceBitwiseFunctions().visit(expr_ast)
    # sympy does not consider the number of bits
    expr_ast = asttools.GetConstMod(nbits).visit(expr_ast)

    # return original type
    if original_type == ast.Expression:
        expr_ast = ast.Expression(expr_ast.body[0].value)
    elif original_type == ast.Expr:
        expr_ast = expr_ast.body[0]
    elif original_type == ast.Module:
        return expr_ast
    else:
        expr_ast = expr_ast.body[0].value
    return expr_ast
예제 #7
0
 def check_eq_z3(self, target, pattern):
     '''Check equivalence with z3.

     ``pattern`` has its wildcards replaced by their current values from
     ``self.wildcards`` before the comparison, and every name in
     ``self.variables`` is modelled as a ``self.nbits``-wide z3 bit-vector.

     Returns False without proving anything when either side contains a
     function call, when an upper-case name (presumably an unreplaced
     wildcard) is left in the substituted pattern, or when the target
     evaluates to the integer 0.  Otherwise returns True iff z3 shows the
     two sides can never differ.
     '''
     # pylint: disable=eval-used
     getid = asttools.GetIdentifiers()
     getid.visit(target)
     if getid.functions:
         # not checking exprs with functions for now, because Z3
         # does not seem to support function declaration with
         # arbitrary number of arguments
         return False
     # z3 bit-vectors kept in an explicit namespace: exec() cannot reliably
     # create function locals under Python 3
     z3_vars = {var: z3.BitVec(var, self.nbits) for var in self.variables}
     target_ast = deepcopy(target)
     target_ast = asttools.Unleveling().visit(target_ast)
     ast.fix_missing_locations(target_ast)
     code1 = compile(ast.Expression(target_ast), '<string>', mode='eval')
     eval_pattern = deepcopy(pattern)
     EvalPattern(self.wildcards).visit(eval_pattern)
     eval_pattern = asttools.Unleveling().visit(eval_pattern)
     ast.fix_missing_locations(eval_pattern)
     getid.reset()
     getid.visit(eval_pattern)
     if getid.functions:
         # same reason as before, not using Z3 if there are
         # functions
         return False
     gvar = asttools.GetIdentifiers()
     gvar.visit(eval_pattern)
     if any(var.isupper() for var in gvar.variables):
         # do not check if all patterns have not been replaced
         return False
     code2 = compile(ast.Expression(eval_pattern), '<string>', mode='eval')
     sol = z3.Solver()
     # evaluate the target once instead of once per use
     target_value = eval(code1, globals(), z3_vars)
     if isinstance(target_value, int) and target_value == 0:
         # cases where target == 0 are too permissive
         return False
     sol.add(target_value != eval(code2, globals(), z3_vars))
     # unsat: no assignment makes target and pattern differ
     return sol.check() == z3.unsat
예제 #8
0
    def get_model(self, target, pattern):
        '''When target is constant and wildcards have no value yet.

        Looks for a value of the single wildcard of ``pattern`` that makes
        ``pattern`` evaluate to the constant ``target``.  Returns False when
        the target is 0, when the pattern contains function calls, or when it
        does not have exactly one wildcard.

        If the wildcard already has a constant value, the pattern is folded
        with it and compared to the target.  Otherwise z3 is asked for a
        ``self.nbits``-wide solution; on success every declaration in the
        model is stored in ``self.wildcards`` and True is returned.
        '''
        # pylint: disable=eval-used
        if target.n == 0:
            # zero is too permissive
            return False
        getwild = asttools.GetIdentifiers()
        getwild.visit(pattern)
        if getwild.functions:
            # not getting model for expr with functions
            return False
        wilds = getwild.variables
        # let's reduce the model to one wildcard for now
        # otherwise it adds a lot of checks...
        # A pattern with no wildcard has nothing to solve for (and would
        # make wilds.pop() below raise).
        if len(wilds) != 1:
            return False

        wil = wilds.pop()
        if wil in self.wildcards:
            if not isinstance(self.wildcards[wil], ast.Num):
                return False
            folded = deepcopy(pattern)
            folded = Unflattening().visit(folded)
            EvalPattern(self.wildcards).visit(folded)
            folded = asttools.ConstFolding(folded, self.nbits).visit(folded)
            return folded.n == target.n
        # unbound wildcard: model it as a z3 bit-vector, kept in an explicit
        # namespace since exec() cannot reliably create function locals in
        # Python 3
        z3_vars = {wil: z3.BitVec(wil, self.nbits)}
        eval_pattern = deepcopy(pattern)
        eval_pattern = Unflattening().visit(eval_pattern)
        ast.fix_missing_locations(eval_pattern)
        code = compile(ast.Expression(eval_pattern), '<string>', mode='eval')
        sol = z3.Solver()
        sol.add(target.n == eval(code, globals(), z3_vars))
        if sol.check() == z3.sat:
            model = sol.model()
            for inst in model.decls():
                self.wildcards[str(inst)] = ast.Num(int(model[inst].as_long()))
            return True
        return False