def testWP(self):
    """Check ``wp_stmt`` against every case in ``self.testWPCases``.

    Each case is a tuple ``(expected, stmt, post, typeEnv)``.  ``stmt`` is
    parsed with ``parseStmt`` and passed to ``wp_stmt`` together with the
    postcondition ``post`` and ``typeEnv``.  The result must be structurally
    identical (``z3.eq``) to ``expected``.
    """
    for case in self.testWPCases:
        expected, stmt, post, typeEnv = case
        parsed = parseStmt(stmt)
        got = wp_stmt(parsed, post, typeEnv)
        # The message is only formatted when the assertion fails.
        assert z3.eq(got, expected), \
            "Expected wp {} got {} from pred {} over stmt {}".format(
                expected, got, post, stmt)
Ejemplo n.º 2
0
 def insert(e):
     """Append *e* to the enclosing ``eq_terms`` list unless it is already there.

     Membership is decided by structural z3 equality (``z3.eq``), not by
     Python ``==`` (which would build a symbolic equation).

     Returns True if *e* was appended, False if an identical term was
     already present.
     """
     if any(z3.eq(existing, e) for existing in eq_terms):
         return False
     eq_terms.append(e)
     return True
 def testSP(self):
     """Check ``sp_stmt`` against every case in ``self.testSPCases``.

     Each case is a tuple ``(pre, stmt, expected, typeEnv)``.  ``stmt`` is
     parsed with ``parseStmt`` and passed to ``sp_stmt`` together with the
     precondition ``pre`` and ``typeEnv``.  The result must be structurally
     identical (``z3.eq``) to ``expected``.
     """
     for case in self.testSPCases:
         pre, stmt, expected, typeEnv = case
         parsed = parseStmt(stmt)
         got = sp_stmt(parsed, pre, typeEnv)
         # The message is only formatted when the assertion fails.
         assert z3.eq(got, expected), \
             "Expected sp {} got {} from pred {} over stmt {}".format(
                 expected, got, pre, stmt)
Ejemplo n.º 4
0
def get_function_from_constraints(contract, constraints):
    """Return the first function of *contract* whose selector appears in *constraints*.

    For each function in ``contract.functions`` the constraint
    ``Extract(255, 224, calldata[0]) == int(function.hash, 16)`` is built.
    That is, the top 32 bits of the 256-bit symbolic word
    ``calldata_<contract.name>[0]`` equal the function's hash.  The function
    is returned if any element of *constraints* is structurally identical
    (``eq``) to that constraint.

    :param contract: object with ``name`` and an iterable ``functions`` whose
        items carry a hex-string ``hash``.
    :param constraints: collection of z3 boolean expressions; it is iterated
        once per function, so it should be re-iterable (e.g. a list).
    :return: the matching function, or None if no function matches.
    """
    # TODO: first search for constraints that could restrict the function hash.
    # TODO: a "calldata length > 4" constraint could be searched for too.

    # The selector expression depends only on the contract, not on the
    # function, so build it once instead of once per function.
    selector = Extract(255, 224, BitVec("calldata_" + contract.name + "[0]", 256))
    for function in contract.functions:
        function_constraint = selector == int(function.hash, 16)
        if any(eq(constraint, function_constraint) for constraint in constraints):
            return function
    return None
Ejemplo n.º 5
0
def testZ3Distinctness():
    '''
    This test is simply a playground to explore
    how z3 handles distinctness and equality checking.

    Two constants of ``language.PointSort`` are declared.  Each gets a
    single-argument ``Distinct`` (vacuously satisfied).  Then
    ``Not(eq(x, y))`` is added, and the solver and its ``check()`` result
    are printed.
    '''
    s = Solver()
    x, y = Consts('x y', language.PointSort)
    s.add(Distinct(x))
    s.add(Distinct(y))
    # Function-call form of print: valid under Python 3.  With a single
    # argument it prints the same thing under Python 2.
    print(s)
    # Alternatives explored earlier:
    # print(s.add(Not(eq(x, y))))
    # print(eq(simplify(x), simplify(y)))
    # s.add(eq(x, y))
    # NOTE: assuming ``eq`` is ``z3.eq``, it is a *structural* AST comparison
    # returning a Python bool (False here, as x and y are different
    # constants).  So this adds ``Not(False)``, not the disequality ``x != y``.
    s.add(Not(eq(x, y)))
    print(s.check())
    print(s)
Ejemplo n.º 6
0
 def equivalent(self, shape: "Shape") -> bool:
     """Returns true iff the given shape is equivalent to this shape.

     Two shapes are equivalent when their ``offsets_with_payloads`` sequences
     have the same length and agree pairwise:

     * offsets must compare equal with ``==``/``!=``;
     * payloads that are both z3 ``ExprRef`` must be structurally identical
       (``eq``);
     * payloads that are both non-z3 must compare equal with ``==``/``!=``.

     A z3 expression paired with a non-z3 payload is treated as *not*
     equivalent.  Comparing such a pair with ``!=`` would build a symbolic
     ``Distinct`` term.  z3 cannot cast that term to ``bool``, so it would
     raise instead of answering.
     """
     if len(self.offsets_with_payloads) != len(shape.offsets_with_payloads):
         return False
     for (v1, p1), (v2, p2) in zip(
             self.offsets_with_payloads,  # type: ignore[arg-type]
             shape.offsets_with_payloads  # type: ignore[arg-type]
     ):
         if v1 != v2:
             return False
         p1_is_expr = isinstance(p1, ExprRef)
         p2_is_expr = isinstance(p2, ExprRef)
         if p1_is_expr and p2_is_expr:
             if not eq(p1, p2):
                 return False
         elif p1_is_expr or p2_is_expr:
             # Mixed symbolic/concrete payloads: never structurally equal.
             return False
         elif p1 != p2:
             return False
     return True
 def grow(self):
     """Split the soft literals into a satisfiable part and a correcting part.

     Presumably an MSS / MCS computation (maximal satisfiable subset /
     minimal correction set), judging by the attribute names — confirm.

     ``self.mss``, ``self.mcs``, ``self.nmcs`` and ``self.mcs_explain`` are
     reset.  Then ``self.unknown`` starts from ``self.soft_vars`` and
     ``update_unknown`` is called.  Literals are popped from ``self.unknown``
     one at a time and checked together with ``self.mss`` and ``self.nmcs``:

     * ``sat``: the literal joins ``self.mss`` and ``update_unknown`` runs
       again;
     * ``unsat``: the literal joins ``self.mcs`` and its negation joins
       ``self.nmcs``.  The resolved unsat core, minus the literal itself, is
       recorded in ``self.mcs_explain`` under ``Not(literal)``;
     * anything else: a message is printed and ``exit()`` is called.

     Finally the literals of ``self.orig_soft_vars`` that are true in
     ``self.model`` are returned.  The solver gets a clause requiring at least
     one of the remaining literals.  ``relax_core`` is applied, smallest first,
     to every collected core that shares no literal with a core relaxed
     before it.
     """
     self.mss, self.mcs, self.nmcs = [], [], []
     self.mcs_explain = {}
     self.unknown = self.soft_vars
     self.update_unknown()
     found_cores = []
     while self.unknown:
         lit = self.unknown.pop()
         verdict = self.s.check(self.mss + [lit] + self.nmcs)
         if verdict == z3.sat:
             self.mss.append(lit)
             self.update_unknown()
             continue
         if verdict != z3.unsat:
             # Neither sat nor unsat (e.g. unknown): give up entirely.
             print("solver returned %s" % verdict)
             exit()
         core = self.resolve_core(self.s.unsat_core())
         self.mcs_explain[z3.Not(lit)] = {y for y in core if not z3.eq(lit, y)}
         self.mcs.append(lit)
         self.nmcs.append(z3.Not(lit))
         found_cores.append(core)

     mss = [x for x in self.orig_soft_vars if z3.is_true(self.model[x])]
     mcs = [x for x in self.orig_soft_vars if not z3.is_true(self.model[x])]
     self.s.add(z3.Or(mcs))

     relaxed = set()
     found_cores.sort(key=len)
     for core in found_cores:
         # Only relax cores disjoint from everything relaxed so far.
         if not (core & relaxed):
             self.relax_core(core)
             relaxed |= core
     return mss
Ejemplo n.º 8
0
def outline_loop (name, decls, rules, self_loops, preds, succs):
    """Outline the self-loop of predicate ``name`` into a procedure predicate.

    A new function declaration named ``decls[name].name() + '_proc'`` is
    created.  Its arguments are the loop-varying arguments of ``name`` twice
    (current values, then final values) and its range is that of
    ``decls[name]``.  The rules are then rewritten in place (see the inline
    comments for each rule shape):

    * each self-loop rule of ``name`` becomes a recursive rule of ``_proc``;
    * each rule whose body predicate is ``name`` yields an exit rule of
      ``_proc``;
    * every (incoming, outgoing) pair of rules through ``name`` is fused into
      one rule that uses ``_proc`` in place of ``name``;
    * the original incoming and outgoing rules of ``name`` are removed.

    NOTE: uses Python 2 ``print`` statements.

    :param name: predicate to outline; must be a key of ``self_loops``.
    :param decls: dict of predicate name -> z3 function declaration.  The
        ``_proc`` declaration is added and ``name`` is deleted.
    :param rules: list of ``Rule`` objects; mutated in place.
    :param self_loops: dict of predicate name -> self-loop rules.  Fused
        rules that are themselves self-loops are added here.
    :param preds: dict of head-predicate name -> rules with that head.
        ``preds[name]`` is deleted.
    :param succs: dict of body-predicate name -> rules with that body
        predicate.  ``succs[name]`` is deleted.

    NOTE(review): ``self_loops[name]`` is not removed, even though its rules
    are removed from ``rules`` and ``decls[name]`` is deleted — confirm the
    caller handles this.
    """
    assert name in self_loops
    loop_rules = self_loops [name]
    decl = decls [name]
    # Presumably the argument positions the self-loop may change.  Positions
    # not in this list are treated as loop-invariant further below.
    # TODO confirm against find_arg_indices.
    arg_indices = find_arg_indices (decl, loop_rules)
    print 'Indices for procedure args:', arg_indices
    arg_sorts = [decl.domain (i) for i in arg_indices]

    # make new procedure
    # signature: _proc(current loop args..., final loop args...) -> decl.range ()
    proc_sorts = arg_sorts + arg_sorts + [decl.range ()]
    proc_name = decl.name () + '_proc'
    proc_decl = z3.Function (proc_name, *proc_sorts)

    # update decls
    decls [proc_name] = proc_decl
    del decls [name]

    # factor the self-loop out into a procedure;
    # can have multiple self-loop rules because of the transformation

    # Self loop: P(x) /\ t(x,x') -> P(x')
    # Transformed rule: t(x,x') /\ P_proc(x',x!next) -> P_proc(x,x!next)
    for rule in loop_rules:
        # x: loop-varying args of the body instance P(x)
        curr_args = [rule.get_body_inst ().arg (i) for i in arg_indices]
        # x!next: fresh constants (one per loop-varying arg) for the final values
        next_args = [z3.Const (arg.decl ().name () + '!next', arg.sort ()) \
                     for arg in curr_args]
        # x': loop-varying args of the head instance P(x')
        temp_args = [rule.get_head ().arg (i) for i in arg_indices]
        rule_exp = z3.ForAll (rule.qvars + next_args,
                              z3.Implies (z3.And (proc_decl (*(temp_args + next_args)),
                                                  rule.get_trans ()),
                                          proc_decl (*(curr_args + next_args))))
        print 'Removing loop rule', repr (rule)
        rules.remove (rule)
        rules.append (Rule (rule_exp))

    # Exit rule: P(x) /\ t(x,y) -> Q(y)
    # Exit rule of P_proc: t(x,y) -> P_proc (x,x)
    for rule in succs [name]:
        curr_args = [rule.get_body_inst ().arg (i) for i in arg_indices]
        # on exit the final values equal the current ones
        next_args = curr_args
        rule_exp = z3.ForAll (rule.qvars,
                              z3.Implies (rule.get_trans (),
                                          proc_decl (*(curr_args + next_args))))
        rules.append (Rule (rule_exp))

    # wire incoming and outgoing rules, via P, together
    # in-rule: Q(y) /\ t(y,x) -> P(x)
    # out-rule: P(x) /\ t(x,z) -> R(z)
    # new rule: Q(y) /\ t(y,x) /\ P_proc (x,x') /\ t(x',z) -> R(z)
    for in_rule in preds [name]:
        #print 'In rule:', repr (in_rule)
        if in_rule.has_body_pred ():
            from_inst = in_rule.get_body_inst ()
        else:
            # in-rule without a body predicate: nothing to require from Q
            from_inst = z3.BoolVal (True)
        in_trans = in_rule.get_trans ()
        curr_args = [in_rule.get_head ().arg (i) for i in arg_indices]

        for out_rule in succs [name]:
            #print 'Out rule:', repr (out_rule)

            # version all out_rule.qvars; treat body args and the rest separately
            sub = list ()
            out_vars = list ()
            # body args
            for i in range (decl.arity ()):
                v = out_rule.get_body_inst ().arg (i)
                if i in arg_indices:
                    # create new variable
                    v_new = z3.Const (v.decl ().name () + '_temp', v.sort ())
                    out_vars.append (v_new)
                else:
                    # use corresponding arg from in_rule; it is unchanged by the loop
                    v_new = in_rule.get_head ().arg (i)
                sub.append ( (v,v_new) )
            # rest
            for v in out_rule.qvars:
                # skip qvars already renamed above as body args
                is_body_arg = False
                for i in range (decl.arity ()):
                    if z3.eq (v, out_rule.get_body_inst ().arg (i)):
                        is_body_arg = True
                        break
                if is_body_arg: continue

                v_temp = z3.Const (v.decl ().name () + '_temp', v.sort ())
                sub.append ( (v,v_temp) )
                out_vars.append (v_temp)

            # x': the renamed loop-varying args of the out-rule's body instance
            next_args = [z3.substitute (out_rule.get_body_inst ().arg (i), *sub)\
                         for i in arg_indices]
            out_trans = z3.substitute (out_rule.get_trans (), *sub)
            to_inst = z3.substitute (out_rule.get_head (), *sub)

            # make new rule via proc_decl
            rule_exp = z3.ForAll (in_rule.qvars + out_vars,
                                  z3.Implies (z3.And (from_inst,
                                                      in_trans,
                                                      proc_decl (*(curr_args + next_args)),
                                                      out_trans),
                                              to_inst))
            rule = Rule (rule_exp)
            # register the fused rule in the same index dicts as other rules
            if rule.is_self_loop ():
                insert_in_dict (self_loops, rule.get_head_pred_name (), rule)
            else:
                insert_in_dict (preds, rule.get_head_pred_name (), rule)
                if rule.has_body_pred ():
                    insert_in_dict (succs, rule.get_body_pred_name (), rule)
            rules.append (rule)

    # remove incoming and outgoing rules, via P
    for in_rule in preds [name]:
        print 'Removing in rule', repr (in_rule)
        rules.remove (in_rule)
        if in_rule.has_body_pred ():
            remove_from_dict (succs, in_rule.get_body_pred_name (), in_rule)
    del preds [name]
    for out_rule in succs [name]:
        print 'Removing out rule', repr (out_rule)
        rules.remove (out_rule)
        remove_from_dict (preds, out_rule.get_head_pred_name (), out_rule)
    del succs [name]
Ejemplo n.º 9
0
def is_data_const(const):
    """Return whether the z3 constant *const* has the module's ``data_sort``.

    The sort comparison is structural (``z3.eq``).  The constant-ness
    precondition is checked with ``assert``, so it is skipped under
    ``python -O``.
    """
    assert z3.is_const(const)
    const_sort = const.sort()
    return z3.eq(const_sort, data_sort)
Ejemplo n.º 10
0
    def _update_all_ref_tracker(self, instruction: Instruction):
        """Update the reference trackers for ``instruction`` and create new ones.

        1. ``update`` is called on every call-result, timestamp and
           reentrancy tracker.
        2. New trackers are created depending on the opcode:

           * CALL/STATICCALL/DELEGATECALL/CALLCODE: a ``ReentrancyTracker``
             already marked ``buggy`` if none exists yet, and a
             ``CallResultTracker`` unless some immutable-storage tracker
             contains stack index ``len(self._stack) - 2``;
           * TIMESTAMP: a ``TimestampDepTracker``;
           * SLOAD: for an address not in ``self.mutable_storage_addresses``,
             an existing matching ``ImmutableStorageTracker`` gets
             ``new(h)`` (else one is created).  A ``ReentrancyTracker`` is
             created unless an existing one tracks the same address, in which
             case it gets ``new(h)``.
        3. ``update`` is called on every immutable-storage tracker, except
           those flagged ``new_born`` in step 2, which only have the flag
           cleared.
        """
        # update all the references
        for ref in self.call_result_references + self.timestamp_references + self.reentrancy_references:
            ref.update(instruction, self._stack,
                       self.immutable_storage_references)
        # check if there are new references
        if instruction.opcode in [
                'CALL', "STATICCALL", "DELEGATECALL", "CALLCODE"
        ]:
            # Stack height once the call's inputs are popped — presumably the
            # slot of its result; confirm.
            h = len(self._stack) - instruction.input_amount
            if len(self.reentrancy_references) == 0:
                # If no storage has been read so far, i.e. no ReentrancyTracker has been created, then this CALL is treated as having a reentrancy bug
                tmp = ReentrancyTracker(instruction.addr, h, -1)
                tmp.buggy = True
                self.reentrancy_references.append(tmp)
            # Determine whether the CALL's destination address is mutable.
            # The destination is the second slot from the top
            # (index len(self._stack) - 2); the loop breaks as soon as an
            # immutable-storage tracker contains it.
            for ref in self.immutable_storage_references:
                if ref.contains(len(self._stack) - 2):
                    break
            else:
                # Reached when no immutable-storage tracker contains the
                # destination address (original note: the address comes from
                # some mutable storage reference).
                # Only when the destination address is not a fixed value, i.e. is unreliable,
                # new call result reference is generated
                call_ref = CallResultTracker(instruction.addr, h)
                self.call_result_references.append(call_ref)
        elif instruction.opcode == "TIMESTAMP":
            # new timestamp reference is generated here
            # (len(self._stack) is presumably where TIMESTAMP's result lands)
            ref = TimestampDepTracker(instruction.addr, len(self._stack))
            self.timestamp_references.append(ref)
        elif instruction.opcode == "SLOAD":
            storage_addr = self._stack[-1]
            h = len(self._stack) - instruction.input_amount
            # Check whether a new storage tracker is needed (the original note
            # says MutableStorageTracker; the code creates ImmutableStorageTracker)
            if not utils.in_list(self.mutable_storage_addresses, storage_addr):
                # The address is immutable
                for ref in self.immutable_storage_references:
                    if utils.eq(ref.storage_addr, storage_addr):
                        ref.new(h)
                        ref.new_born = True
                        break
                else:
                    # NOTE(review): the new tracker is not flagged new_born here,
                    # so the final loop updates it in this same call unless its
                    # constructor sets new_born — confirm.
                    ref = ImmutableStorageTracker(instruction.addr, h,
                                                  storage_addr,
                                                  self._storage[storage_addr])
                    self.immutable_storage_references.append(ref)
            # Check whether a new ReentrancyTracker needs to be created
            for r in self.reentrancy_references:
                # check if there already exists the same reference
                # (symbolic addresses: structural equality after simplify;
                # concrete addresses: ==)
                try:
                    if utils.is_symbol(storage_addr) and utils.is_symbol(
                            r.storage_addr) and eq(
                                simplify(r.storage_addr),
                                simplify(storage_addr)) or not utils.is_symbol(
                                    storage_addr) and not utils.is_symbol(
                                        r.storage_addr
                                    ) and r.storage_addr == storage_addr:
                        r.new(h)
                        break
                except Exception as e:
                    # NOTE(review): errors are printed and swallowed; the loop
                    # moves on to the next tracker.
                    print(e)
            else:
                ref = ReentrancyTracker(instruction.addr, h, storage_addr)
                self.reentrancy_references.append(ref)

        # Update the immutable-storage trackers (original note: "mutable storage
        # reference"); trackers flagged new_born above are skipped once.
        for ref in self.immutable_storage_references:
            if ref.new_born:
                ref.new_born = False
            else:
                ref.update(instruction, self._stack, None)
Ejemplo n.º 11
0
 def exe_with_path_condition(self,
                             instruction: Instruction,
                             path_condition: "list | None" = None) -> PcPointer:
     """Execute *instruction* while tracking storage trust and reentrancy refs.

     Steps:

     1. Update the reference trackers via ``_update_all_ref_tracker``.
     2. For ``SSTORE``, snapshot the current storage value of every address
        tracked by ``self.reentrancy_references``.
     3. For ``SSTORE`` with ``self.pre_process`` set: if *path_condition* is
        satisfiable, also require the symbolic caller ``Is`` to differ from
        every storage value whose address is not in
        ``self.mutable_storage_addresses``.  If that is still satisfiable,
        the written address (stack top) is recorded as mutable: the write can
        happen even when the caller matches none of the trusted values.
     4. Execute the instruction with the base class's ``exe``.
     5. For ``SSTORE``, compare every snapshot with the new value and flag
        the trackers whose storage changed.

     :param instruction: the instruction to execute.
     :param path_condition: constraints of the current path.  ``None`` (the
         default) means no constraints.
     :return: the ``PcPointer`` returned by the base class's ``exe``.
     """
     if path_condition is None:
         # Avoid a shared mutable default list across calls.
         path_condition = []
     self._update_all_ref_tracker(instruction)
     is_sstore = instruction.opcode == "SSTORE"

     bak = {}
     if is_sstore:
         # Save the value of every referenced storage variable before SSTORE.
         for ref in self.reentrancy_references:
             bak[ref] = self._storage[ref.storage_addr]

     if is_sstore and self.pre_process:
         op0 = self._stack[-1]
         # A storage variable modified under trusted conditions is still
         # trusted (immutable).
         solver = z3.Solver()
         solver.add(path_condition)
         if solver.check() == z3.sat:
             # Only the low 160 bits of concrete values are compared.
             mask = 0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff
             caller = z3.Int("Is")
             for storage_addr, storage_value in self._storage.get_storage(
             ).items():
                 if utils.in_list(self.mutable_storage_addresses,
                                  storage_addr):
                     continue
                 if utils.is_symbol(storage_value):
                     solver.add(caller != z3.Int(str(storage_value)))
                 else:
                     solver.add(caller != storage_value & mask)
             if solver.check() == z3.sat:
                 # If SSTORE can still happen when the caller equals none of
                 # the trusted storage values, the written storage variable
                 # is untrusted.
                 if not utils.in_list(self.mutable_storage_addresses, op0):
                     self.mutable_storage_addresses.append(op0)

     pc_pointer = super().exe(instruction)

     if is_sstore:
         # Check whether any referenced storage variable changed after SSTORE.
         for ref, old_value in bak.items():
             new_value = self._storage[ref.storage_addr]
             old_symbolic = utils.is_symbol(old_value)
             if old_symbolic is not utils.is_symbol(new_value):
                 # Symbolic <-> concrete switch always counts as a change.
                 changed = True
             elif old_symbolic:
                 changed = not eq(simplify(old_value), simplify(new_value))
             else:
                 changed = old_value != new_value
             if changed:
                 ref.storage_changed = True
                 if ref.after_used_in_condition:
                     ref.changed_after_condition = True
                 if not ref.after_call:
                     ref.changed_before_call = True
     return pc_pointer
 def notEqual(v, vars):
     """Return ``z3.And`` of ``v != i`` for every *i* in *vars* other than *v*.

     Elements structurally identical to *v* (``z3.eq``) are skipped, so *v*
     may appear in *vars* without making the conjunction unsatisfiable.
     """
     others = [i for i in vars if not z3.eq(v, i)]
     return z3.And([v != other for other in others])
Ejemplo n.º 13
0
def is_storage_primitive(storage):
    """Return True when every slot of *storage* still holds its symbolic placeholder.

    A falsy *storage* (e.g. None) counts as primitive.  Otherwise each
    ``index -> content`` entry of ``storage._storage`` must be structurally
    equal (``eq``) to ``BitVec("storage[<index>]", 256)``.  Any concrete
    ``int`` content makes the storage non-primitive.
    """
    if not storage:
        return True
    for index, content in storage._storage.items():
        if isinstance(content, int):
            return False
        placeholder = BitVec("storage[" + str(index) + "]", 256)
        if not eq(content, placeholder):
            return False
    return True
Ejemplo n.º 14
0
 def key_fun(root):
     """Sort key for *root*: ``(branching_factor, str(root))``.

     Uses ``utils.find_first`` to pick the first entry of the enclosing
     ``preds`` whose ``root`` is structurally equal (``z3.eq``) to *root*.
     Its ``struct.branching_factor`` comes first in the key, so linear
     structures are prioritized in a lexicographic sort; ``str(root)``
     breaks ties.  Assumes such an entry exists — confirm ``find_first``'s
     behaviour when nothing matches.
     """
     def has_root(candidate):
         return z3.eq(candidate.root, root)

     owner = utils.find_first(has_root, preds)
     return owner.struct.branching_factor, str(root)
 def notEqual(v, vars):
     """Build the conjunction of ``v != other`` over *vars*, skipping *v* itself.

     "Itself" means structurally identical according to ``z3.eq``.
     """
     constraints = []
     for other in vars:
         if z3.eq(v, other):
             continue
         constraints.append(v != other)
     return z3.And(constraints)
Ejemplo n.º 16
0
    def substitute(self, other, expr):
        """Rewrite *expr*, replacing this object's symbols with those of *other*.

        The replacement map is built from:

        * every ``addr, state`` in ``self._expanded_global_state``: the
          ``storage``, ``balance`` and ``nonce`` of ``state`` map to those of
          ``other.initial_contract_state(addr)``;
        * every ``k, v`` in ``self._cache``: ``v`` maps to
          ``getattr(other, k)()``.  For ``mem.Memory`` values the underlying
          ``_mem`` arrays are mapped instead.  Values that are applications
          with arguments (presumably templates like ``f(Var(0), ...)``) are
          matched by function declaration.  Every application of that
          declaration in *expr* is replaced by *other*'s expression with its
          ``Var`` placeholders instantiated (``z3.substitute_vars``) by the
          already-rewritten arguments.

        The traversal uses an explicit stack and memoises results in a
        ``z3.AstMap``, so each distinct sub-term is rewritten once.

        :param other: object providing ``initial_contract_state(addr)`` and a
            zero-argument method named after every key of ``self._cache``.
        :param expr: z3 expression to rewrite.
        :return: the rewritten expression.
        :raises NotImplementedError: if a quantifier is encountered.
        """
        def update_term(t, args):
            # Rebuild application t with new arguments via the C API.  Used
            # when the argument count differs from the declaration's arity
            # (e.g. n-ary operators).
            n = len(args)
            # Need to pass an AstArray type into Z3_update_term, not a python list
            args_ast_arr = (z3.Ast * n)()
            for i in range(n):
                args_ast_arr[i] = args[i].as_ast()
            return _to_expr_ref(
                z3.Z3_update_term(t.ctx_ref(), t.as_ast(), n, args_ast_arr),
                t.ctx)

        cache = z3.AstMap(ctx=expr.ctx)

        for addr, state in self._expanded_global_state.items():
            other_state = other.initial_contract_state(addr)
            cache[state.storage] = other_state.storage
            cache[state.balance] = other_state.balance
            cache[state.nonce] = other_state.nonce

        # (function declaration to replace, replacement template) pairs
        fnsubs = []
        for k, v in self._cache.items():
            if isinstance(v, mem.Memory):
                # Uses of Memory objects will produce an expression containing
                # the underlying array object (_mem). If required the index
                # (_idx) will have been substituted with the actual indexing
                # expression at that point, so we do not need to consider it
                # here.
                cache[v._mem] = getattr(other, k)()._mem
            elif z3.is_app(v) and v.num_args() > 0:
                # Store the declaration, not the application: the lookup
                # below compares against ``n.decl()`` (a FuncDeclRef).  A
                # FuncDeclRef is never structurally equal to an application.
                fnsubs.append((v.decl(), getattr(other, k)()))
            else:
                cache[v] = getattr(other, k)()

        todo = [expr]
        while todo:
            n = todo[-1]
            if n in cache:
                todo.pop()
            elif z3.is_var(n):
                # Bound variables are left unchanged.
                cache[n] = n
                todo.pop()
            elif z3.is_app(n):
                new_args = []
                for i in range(n.num_args()):
                    arg = n.arg(i)
                    if arg not in cache:
                        todo.append(arg)
                    else:
                        new_args.append(cache[arg])
                # Only actually do the substitution if all the arguments have
                # already been processed
                if len(new_args) == n.num_args():
                    todo.pop()
                    fn = n.decl()
                    for olddecl, newfn in fnsubs:
                        if z3.eq(fn, olddecl):
                            new_fn = z3.substitute_vars(newfn, *new_args)
                            break
                    else:
                        # TODO only if new_args != old_args
                        if len(new_args) != fn.arity():
                            new_fn = update_term(n, new_args)
                        else:
                            new_fn = fn(*new_args)
                    cache[n] = new_fn
            else:
                assert z3.is_quantifier(n)
                # Not currently implemented, as quantifiers are not used at
                # the moment
                raise NotImplementedError()
        return cache[expr]