def testWP(self):
    """Check ``wp_stmt`` against every case in ``self.testWPCases``.

    Each case is a ``(expected, stmt, post, typeEnv)`` tuple: ``stmt`` is parsed
    with ``parseStmt``, its weakest precondition w.r.t. ``post`` is computed, and
    the result must be structurally identical (``z3.eq``) to ``expected``.
    """
    for case in self.testWPCases:
        expected, stmt, post, typeEnv = case
        parsed = parseStmt(stmt)
        got = wp_stmt(parsed, post, typeEnv)
        assert z3.eq(got, expected), \
            "Expected wp {} got {} from pred {} over stmt {}".format(
                expected, got, post, stmt)
def insert(e):
    """Append ``e`` to ``eq_terms`` unless a structurally equal term is present.

    ``eq_terms`` is a list from the enclosing scope; equality is z3's
    structural ``z3.eq``.

    Returns:
        True if ``e`` was appended, False if an equal term already existed.
    """
    if any(z3.eq(term, e) for term in eq_terms):
        return False
    eq_terms.append(e)
    return True
def testSP(self):
    """Check ``sp_stmt`` against every case in ``self.testSPCases``.

    Each case is a ``(pre, stmt, expected, typeEnv)`` tuple: ``stmt`` is parsed
    with ``parseStmt``, its strongest postcondition from ``pre`` is computed, and
    the result must be structurally identical (``z3.eq``) to ``expected``.
    """
    for case in self.testSPCases:
        pre, stmt, expected, typeEnv = case
        parsed = parseStmt(stmt)
        got = sp_stmt(parsed, pre, typeEnv)
        assert z3.eq(got, expected), \
            "Expected sp {} got {} from pred {} over stmt {}".format(
                expected, got, pre, stmt)
def get_function_from_constraints(contract, constraints):
    """Find the contract function whose selector constraint appears in ``constraints``.

    For each function in ``contract.functions`` (in order) the constraint
    ``Extract(255, 224, calldata_<name>[0]) == int(function.hash, 16)`` is built,
    i.e. the top 32 bits of the first 256-bit calldata word equal the function
    hash. The first function for which some element of ``constraints`` is
    structurally identical (``eq``) to that constraint is returned.

    Returns:
        The matching function, or None if no function matches.
    """
    # TODO: first we could search for constraints that could be a restriction
    #       to the function hash.
    # TODO: a calldata length > 4 constraint could be searched for too.
    for function in contract.functions:
        calldata_word = BitVec("calldata_" + contract.name + "[0]", 256)
        selector_bits = Extract(255, 224, calldata_word)
        function_constraint = selector_bits == int(function.hash, 16)
        if any(eq(constraint, function_constraint) for constraint in constraints):
            return function
    return None
def testZ3Distinctness():
    '''
    Playground (not a real test: nothing is asserted) to explore how z3
    handles distinctness and equality checking.

    Adds ``Distinct(x)`` and ``Distinct(y)`` for two ``language.PointSort``
    constants, then ``Not(eq(x, y))``, printing the solver state and the
    ``check()`` result along the way.
    '''
    s = Solver()
    x, y = Consts('x y', language.PointSort)
    s.add(Distinct(x))
    s.add(Distinct(y))
    print(s)
    # NOTE(review): assuming ``eq`` is z3.eq, it compares ASTs structurally and
    # returns a Python bool, so this adds Not(False) rather than a symbolic
    # disequality between x and y -- confirm that is what is being explored.
    s.add(Not(eq(x, y)))
    print(s.check())
    print(s)
def equivalent(self, shape: "Shape") -> bool:
    """Return True iff ``shape`` is equivalent to this shape.

    Two shapes are equivalent when their ``offsets_with_payloads`` sequences
    have the same length and agree pairwise: offsets must compare equal with
    ``!=``/``==``, and payloads are compared structurally with ``eq`` when both
    are z3 ``ExprRef`` values, otherwise with ``!=``.
    """
    mine = self.offsets_with_payloads
    theirs = shape.offsets_with_payloads
    if len(mine) != len(theirs):
        return False
    pairs = zip(
        mine,  # type: ignore[arg-type]
        theirs  # type: ignore[arg-type]
    )
    for (offset_a, payload_a), (offset_b, payload_b) in pairs:
        if offset_a != offset_b:
            return False
        both_symbolic = isinstance(payload_a, ExprRef) and isinstance(payload_b, ExprRef)
        if both_symbolic:
            # Structural comparison; `!=` on z3 terms would build a new term.
            if not eq(payload_a, payload_b):
                return False
        elif payload_a != payload_b:
            return False
    return True
def grow(self):
    """Greedily extend the current model to a maximal set of satisfied soft literals.

    Resets ``mss``, ``mcs``, ``nmcs`` and ``mcs_explain`` (presumably "maximal
    satisfiable subset" / "minimal correction set"), then repeatedly pops a
    literal from ``self.unknown`` and checks it together with ``self.mss`` and
    the negations of the literals rejected so far (``self.nmcs``):

    * sat   -> the literal is added to ``mss`` and ``update_unknown`` runs again;
    * unsat -> the literal is rejected; the resolved unsat core (minus the
               literal itself) is stored in ``mcs_explain`` under ``Not(x)``
               and the core is remembered;
    * other -> the result is printed and ``exit()`` is called.

    Afterwards the solver is given ``Or`` of the soft variables that are not
    true in ``self.model``, and the remembered cores are visited smallest first;
    each one disjoint from every core relaxed so far is passed to
    ``relax_core``.

    Returns:
        The soft variables from ``self.orig_soft_vars`` that are true in
        ``self.model``.
    """
    self.mss = []
    self.mcs = []
    self.nmcs = []
    self.mcs_explain = {}
    # NOTE(review): self.unknown aliases self.soft_vars here; the pops below
    # would consume soft_vars itself unless update_unknown() rebinds
    # self.unknown -- confirm.
    self.unknown = self.soft_vars
    self.update_unknown()
    cores = []
    while self.unknown:
        literal = self.unknown.pop()
        outcome = self.s.check(self.mss + [literal] + self.nmcs)
        if outcome == z3.sat:
            self.mss.append(literal)
            self.update_unknown()
        elif outcome == z3.unsat:
            core = self.resolve_core(self.s.unsat_core())
            self.mcs_explain[z3.Not(literal)] = {
                other for other in core if not z3.eq(literal, other)
            }
            self.mcs.append(literal)
            self.nmcs.append(z3.Not(literal))
            cores.append(core)
        else:
            print("solver returned %s" % outcome)
            exit()

    model = self.model
    satisfied = [var for var in self.orig_soft_vars if z3.is_true(model[var])]
    falsified = [var for var in self.orig_soft_vars if not z3.is_true(model[var])]
    # Block the current correction set: at least one falsified literal must flip.
    self.s.add(z3.Or(falsified))

    relaxed = set()
    for core in sorted(cores, key=len):
        if not (core & relaxed):
            self.relax_core(core)
            relaxed |= core
    return satisfied
def outline_loop (name, decls, rules, self_loops, preds, succs):
    """Outline the self-loop of predicate ``name`` into a procedure predicate.

    Steps (all containers are mutated in place; nothing is returned):
      1. Declare ``<name>_proc`` over the sorts of the loop-carried arguments
         (``arg_indices``) twice, plus ``decl.range()``; register it in
         ``decls`` and drop ``decls[name]``.
      2. Replace every self-loop rule of ``name`` in ``rules`` by the
         corresponding procedure rule (see the comment before the loop).
      3. For every outgoing rule of ``name`` add an exit rule of the procedure.
      4. For every (incoming, outgoing) pair of rules of ``name`` add a rule
         that jumps over ``name`` through the procedure, and register it in
         ``self_loops`` or ``preds``/``succs``.
      5. Remove the old incoming/outgoing rules of ``name`` from ``rules``,
         ``preds`` and ``succs``.

    Args:
      name: predicate name; must be a key of ``self_loops``.
      decls: dict from predicate name to its z3 FuncDecl.
      rules: list of ``Rule`` objects.
      self_loops: dict from predicate name to its self-loop rules.
      preds: dict from predicate name to rules whose head is that predicate.
      succs: dict from predicate name to rules whose body is that predicate.

    NOTE(review): ``self_loops[name]`` is never deleted even though its rules
    are removed from ``rules`` -- confirm callers expect the stale entry.
    NOTE(review): uses Python 2 ``print`` statements, so this function does not
    parse under Python 3.
    NOTE(review): the nesting below was reconstructed from a whitespace-mangled
    copy; in particular confirm that the ``succs`` insertion belongs to the
    non-self-loop branch.
    """
    assert name in self_loops
    loop_rules = self_loops [name]
    decl = decls [name]
    # Indices of the arguments that the loop may change; the other arguments
    # are treated below as unchanged by the loop.
    arg_indices = find_arg_indices (decl, loop_rules)
    print 'Indices for procedure args:', arg_indices
    arg_sorts = [decl.domain (i) for i in arg_indices]
    # make new procedure: P_proc(loop-args at entry, loop-args at exit)
    proc_sorts = arg_sorts + arg_sorts + [decl.range ()]
    proc_name = decl.name () + '_proc'
    proc_decl = z3.Function (proc_name, *proc_sorts)
    # update decls
    decls [proc_name] = proc_decl
    del decls [name]
    # factor the self-loop out into a procedure;
    # can have multiple self-loop rules because of the transformation
    # Self loop: P(x) /\ t(x,x') -> P(x')
    # Transformed rule: t(x,x') /\ P_proc(x',x!next) -> P_proc(x,x!next)
    for rule in loop_rules:
        curr_args = [rule.get_body_inst ().arg (i) for i in arg_indices]
        # fresh "x!next" constants, one per loop-carried argument
        next_args = [z3.Const (arg.decl ().name () + '!next', arg.sort ()) \
                     for arg in curr_args]
        temp_args = [rule.get_head ().arg (i) for i in arg_indices]
        rule_exp = z3.ForAll (rule.qvars + next_args,
                              z3.Implies (z3.And (proc_decl (*(temp_args + next_args)),
                                                  rule.get_trans ()),
                                          proc_decl (*(curr_args + next_args))))
        print 'Removing loop rule', repr (rule)
        rules.remove (rule)
        rules.append (Rule (rule_exp))
    # Exit rule: P(x) /\ t(x,y) -> Q(y)
    # Exit rule of P_proc: t(x,y) -> P_proc (x,x)
    for rule in succs [name]:
        curr_args = [rule.get_body_inst ().arg (i) for i in arg_indices]
        next_args = curr_args
        rule_exp = z3.ForAll (rule.qvars, z3.Implies (rule.get_trans (), proc_decl (*(curr_args + next_args))))
        rules.append (Rule (rule_exp))
    # wire incoming and outgoing rules, via P, together
    # in-rule: Q(y) /\ t(y,x) -> P(x)
    # out-rule: P(x) /\ t(x,z) -> R(z)
    # new rule: Q(y) /\ t(y,x) /\ P_proc (x,x') /\ t(x',z) -> R(z)
    for in_rule in preds [name]:
        #print 'In rule:', repr (in_rule)
        # An in-rule without a body predicate is a fact; use True as its body.
        if in_rule.has_body_pred ():
            from_inst = in_rule.get_body_inst ()
        else:
            from_inst = z3.BoolVal (True)
        in_trans = in_rule.get_trans ()
        curr_args = [in_rule.get_head ().arg (i) for i in arg_indices]
        for out_rule in succs [name]:
            #print 'Out rule:', repr (out_rule)
            # version all out_rule.qvars; treat body args and the rest separately
            sub = list ()
            out_vars = list ()
            # body args
            for i in range (decl.arity ()):
                v = out_rule.get_body_inst ().arg (i)
                if i in arg_indices:
                    # create new variable
                    v_new = z3.Const (v.decl ().name () + '_temp', v.sort ())
                    out_vars.append (v_new)
                else:
                    # use corresponding arg from in_rule; it is unchanged by the loop
                    v_new = in_rule.get_head ().arg (i)
                sub.append ( (v,v_new) )
            # rest: rename every other quantified variable of out_rule to a
            # fresh '_temp' copy so it cannot clash with in_rule's variables
            for v in out_rule.qvars:
                is_body_arg = False
                for i in range (decl.arity ()):
                    if z3.eq (v, out_rule.get_body_inst ().arg (i)):
                        is_body_arg = True
                        break
                if is_body_arg: continue
                v_temp = z3.Const (v.decl ().name () + '_temp', v.sort ())
                sub.append ( (v,v_temp) )
                out_vars.append (v_temp)
            next_args = [z3.substitute (out_rule.get_body_inst ().arg (i), *sub)\
                         for i in arg_indices]
            out_trans = z3.substitute (out_rule.get_trans (), *sub)
            to_inst = z3.substitute (out_rule.get_head (), *sub)
            # make new rule via proc_decl
            rule_exp = z3.ForAll (in_rule.qvars + out_vars,
                                  z3.Implies (z3.And (from_inst, in_trans,
                                                      proc_decl (*(curr_args + next_args)),
                                                      out_trans),
                                              to_inst))
            rule = Rule (rule_exp)
            # Register the new rule in the same indices the rest of the
            # pipeline uses: self-loops separately, other rules by head/body.
            if rule.is_self_loop ():
                insert_in_dict (self_loops, rule.get_head_pred_name (), rule)
            else:
                insert_in_dict (preds, rule.get_head_pred_name (), rule)
                if rule.has_body_pred ():
                    insert_in_dict (succs, rule.get_body_pred_name (), rule)
            rules.append (rule)
    # remove incoming and outgoing rules, via P
    for in_rule in preds [name]:
        print 'Removing in rule', repr (in_rule)
        rules.remove (in_rule)
        if in_rule.has_body_pred ():
            remove_from_dict (succs, in_rule.get_body_pred_name (), in_rule)
    del preds [name]
    for out_rule in succs [name]:
        print 'Removing out rule', repr (out_rule)
        rules.remove (out_rule)
        remove_from_dict (preds, out_rule.get_head_pred_name (), out_rule)
    del succs [name]
def is_data_const(const):
    """Return whether the z3 constant ``const`` has sort ``data_sort``.

    The sort comparison is structural (``z3.eq``). An ``AssertionError`` is
    raised (when assertions are enabled) if ``const`` is not a z3 constant.
    """
    assert z3.is_const(const)
    sort_of_const = const.sort()
    return z3.eq(sort_of_const, data_sort)
def _update_all_ref_tracker(self, instruction: Instruction):
    """Advance every reference tracker past ``instruction`` and start new ones.

    1. Every call-result, timestamp and reentrancy tracker gets
       ``update(instruction, self._stack, self.immutable_storage_references)``.
    2. CALL / STATICCALL / DELEGATECALL / CALLCODE:
       - if no reentrancy tracker exists yet, a ``ReentrancyTracker`` with
         storage address ``-1`` is added and marked ``buggy``;
       - unless some immutable storage reference ``contains`` stack slot
         ``len(self._stack) - 2`` (presumably the call target -- confirm),
         a ``CallResultTracker`` is added.
    3. TIMESTAMP: a ``TimestampDepTracker`` is added.
    4. SLOAD (slot address taken from the top of the stack):
       - if the address is not in ``self.mutable_storage_addresses``, an
         existing ``ImmutableStorageTracker`` for it is restarted
         (``new(h)`` + ``new_born``), or a new one is created;
       - an existing ``ReentrancyTracker`` for the same address is restarted,
         or a new one is created.
    5. Every immutable storage tracker is updated, except those flagged
       ``new_born`` by this instruction, which only have the flag cleared.

    ``h`` is ``len(self._stack) - instruction.input_amount`` (presumably the
    stack height once the instruction's inputs are popped -- confirm).
    """
    # update all the references
    for ref in self.call_result_references + self.timestamp_references + self.reentrancy_references:
        ref.update(instruction, self._stack, self.immutable_storage_references)

    # check if there are new references
    if instruction.opcode in [
            'CALL', "STATICCALL", "DELEGATECALL", "CALLCODE"
    ]:
        h = len(self._stack) - instruction.input_amount
        if len(self.reentrancy_references) == 0:
            # If no storage has been read before (i.e. no ReentrancyTracker
            # has been created yet), the current call is flagged as having a
            # reentrancy bug.
            tmp = ReentrancyTracker(instruction.addr, h, -1)
            tmp.buggy = True
            self.reentrancy_references.append(tmp)
        # Determine whether the call's destination address is mutable.
        for ref in self.immutable_storage_references:
            if ref.contains(len(self._stack) - 2):
                break
        else:
            # Reached only when no immutable (trusted) storage reference holds
            # the destination address, i.e. the destination is not a
            # determined, trusted value.
            # new call result reference is generated
            call_ref = CallResultTracker(instruction.addr, h)
            self.call_result_references.append(call_ref)
    elif instruction.opcode == "TIMESTAMP":
        # new timestamp reference is generated here
        ref = TimestampDepTracker(instruction.addr, len(self._stack))
        self.timestamp_references.append(ref)
    elif instruction.opcode == "SLOAD":
        storage_addr = self._stack[-1]
        h = len(self._stack) - instruction.input_amount
        # Check whether a new storage tracker is needed (the code below
        # creates an ImmutableStorageTracker).
        if not utils.in_list(self.mutable_storage_addresses, storage_addr):
            # The slot is immutable (trusted): restart its tracker if one
            # already exists, otherwise create one.
            for ref in self.immutable_storage_references:
                if utils.eq(ref.storage_addr, storage_addr):
                    ref.new(h)
                    ref.new_born = True
                    break
            else:
                ref = ImmutableStorageTracker(instruction.addr, h, storage_addr, self._storage[storage_addr])
                self.immutable_storage_references.append(ref)
        # Check whether a new ReentrancyTracker needs to be created.
        for r in self.reentrancy_references:
            # check if there already exists the same reference
            # (`and` binds tighter than `or`: both addresses symbolic and equal
            # after simplify, OR both concrete and ==)
            try:
                if utils.is_symbol(storage_addr) and utils.is_symbol(
                        r.storage_addr) and eq(
                            simplify(r.storage_addr),
                            simplify(storage_addr)) or not utils.is_symbol(
                                storage_addr) and not utils.is_symbol(
                                    r.storage_addr
                                ) and r.storage_addr == storage_addr:
                    r.new(h)
                    break
            except Exception as e:
                # Best-effort comparison: errors are printed and the search
                # continues with the next tracker.
                print(e)
        else:
            ref = ReentrancyTracker(instruction.addr, h, storage_addr)
            self.reentrancy_references.append(ref)
    # Update the storage references (this walks the *immutable* storage
    # trackers); trackers just restarted by this SLOAD are skipped once.
    for ref in self.immutable_storage_references:
        if ref.new_born:
            ref.new_born = False
        else:
            ref.update(instruction, self._stack, None)
def exe_with_path_condition(self, instruction: Instruction, path_condition: list = []) -> PcPointer:
    """Execute one instruction while tracking storage trust and reentrancy state.

    Before execution:
      * all reference trackers are advanced (``_update_all_ref_tracker``);
      * for SSTORE, the current value of every slot watched by a reentrancy
        tracker is saved in ``bak``;
      * for SSTORE with ``self.pre_process`` set: if ``path_condition`` is
        satisfiable, the symbolic caller ``Is`` is additionally constrained to
        differ from every trusted (non-mutable) storage value; if that is still
        satisfiable, the written slot (top of stack) is recorded in
        ``self.mutable_storage_addresses``.

    After ``super().exe(instruction)``, for SSTORE, every reentrancy tracker
    whose watched slot value changed gets ``storage_changed``, plus
    ``changed_after_condition`` if ``after_used_in_condition`` is set and
    ``changed_before_call`` if ``after_call`` is not set.

    Args:
        instruction: the instruction to execute.
        path_condition: constraints passed to ``z3.Solver.add``; only read
            here. NOTE(review): mutable default ``[]`` -- safe only as long as
            nothing mutates it.

    Returns:
        The ``PcPointer`` returned by ``super().exe(instruction)``.
    """
    self._update_all_ref_tracker(instruction)
    if instruction.opcode == "SSTORE":
        # save the value of every referred storage variable before SSTORE
        bak = {}
        for ref in self.reentrancy_references:
            bak[ref] = self._storage[ref.storage_addr]

    if instruction.opcode == "SSTORE" and self.pre_process:
        # Slot being written (top of stack).
        op0 = self._stack[-1]
        # A storage variable modified under trusted conditions remains
        # trusted (immutable).
        # NOTE(review): `caller` is unused; the constraints below rebuild
        # z3.Int("Is") inline.
        caller = z3.Int("Is")
        solver = z3.Solver()
        solver.add(path_condition)
        if "sat" == str(solver.check()):
            for storage_addr, storage_value in self._storage.get_storage().items():
                if not utils.in_list(self.mutable_storage_addresses, storage_addr):
                    # solver.add(caller & 0xffffffffffffffffffff != storage_value & 0xffffffffffffffffffff)
                    # Low 160 bits (presumably an address width -- confirm).
                    mask = 0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff
                    if utils.is_symbol(storage_value):
                        # Earlier attempts, kept for reference:
                        # solver.add(z3.Int(str("Is") + "&" + str(mask)) != z3.Int(str(storage_value) + "&" + str(mask)))
                        # solver.add(z3.Int(str(mask) + "&" + str("Is")) != z3.Int(str(storage_value) + "&" + str(mask)))
                        # solver.add(z3.Int(str("Is") + "&" + str(mask)) != z3.Int(str(mask) + "&" + str(storage_value)))
                        # solver.add(z3.Int(str(mask) + "&" + str("Is")) != z3.Int(str(mask) + "&" + str(storage_value)))
                        # solver.add(z3.Int(str("Is")) != z3.Int(str(storage_value)))
                        # Symbolic value: compared as a fresh Int named after
                        # the expression's string form.
                        solver.add(
                            z3.Int(str("Is")) != z3.Int(str(storage_value))
                        )
                    else:
                        # solver.add(z3.Int(str("Is") + "&" + str(mask)) != storage_value & mask)
                        # solver.add(z3.Int(str(mask) + "&" + str("Is")) != storage_value & mask)
                        # `&` binds tighter than `!=`: Is != (value & mask).
                        solver.add(
                            z3.Int(str("Is")) != storage_value & mask)
            if "sat" == str(solver.check()):
                # If SSTORE can still happen when the caller equals none of the
                # trusted storage values, the modified storage variable is
                # untrusted (mutable).
                if not utils.in_list(self.mutable_storage_addresses, op0):
                    self.mutable_storage_addresses.append(op0)
    pc_pointer = super().exe(instruction)
    if instruction.opcode == "SSTORE":
        # check if any referred storage variable is changed after SSTORE
        # if len(bak) != len(self._storage):
        #     # If a new storage variable was added, it must have been modified
        #     ref.storage_changed = True
        #     if ref.after_used_in_condition:
        #         ref.changed_after_condition = True
        #     if not ref.after_call:
        #         ref.changed_before_call = True
        # else:
        # If the number of storage variables is unchanged, check whether each
        # variable's value has changed. A value counts as changed when it
        # switched between symbolic and concrete, when two symbolic values
        # differ structurally after simplify, or when two concrete values
        # differ.
        for ref, value in bak.items():
            if utils.is_symbol(value) is not utils.is_symbol(self._storage[ref.storage_addr]) or \
                    utils.is_symbol(value) and not eq(simplify(value), simplify(self._storage[ref.storage_addr])) or \
                    not utils.is_symbol(value) and value != self._storage[ref.storage_addr]:
                ref.storage_changed = True
                if ref.after_used_in_condition:
                    ref.changed_after_condition = True
                if not ref.after_call:
                    ref.changed_before_call = True
    return pc_pointer
def notEqual(v, vars):
    """Return ``z3.And`` of ``v != other`` for each ``other`` in ``vars``.

    Elements structurally identical to ``v`` (``z3.eq``) are skipped, so ``v``
    is never required to differ from itself.
    """
    disequalities = []
    for other in vars:
        if z3.eq(v, other):
            continue
        disequalities.append(v != other)
    return z3.And(disequalities)
def is_storage_primitive(storage):
    """Return True if every slot of ``storage`` holds its own named symbol.

    A slot ``index`` qualifies when its content is not an ``int`` (``bool``
    included) and is structurally equal (``eq``) to
    ``BitVec("storage[<index>]", 256)``. A falsy ``storage`` counts as
    primitive; the first non-qualifying slot short-circuits to False.
    """
    if not storage:
        return True
    return all(
        not isinstance(content, int)
        and eq(content, BitVec("storage[" + str(index) + "]", 256))
        for index, content in storage._storage.items()
    )
def key_fun(root):
    """Sort key for ``root``: ``(branching factor of its struct, str(root))``.

    The predicate is the first element of ``preds`` (enclosing scope) whose
    ``root`` is structurally equal (``z3.eq``) to ``root``, as found by
    ``utils.find_first``.
    """
    def has_this_root(candidate):
        return z3.eq(candidate.root, root)

    pred = utils.find_first(has_this_root, preds)
    # Prepend branching factor to prioritize linear structs in lexicographic sort
    branching = pred.struct.branching_factor
    return (branching, str(root))
def notEqual(v, vars):
    """Return ``z3.And`` over ``v != other`` for every ``other`` in ``vars``
    that is not structurally identical to ``v`` (``z3.eq``).

    NOTE(review): this definition is identical to another ``notEqual`` in the
    same source; the later definition shadows the earlier one.
    """
    others = (other for other in vars if not z3.eq(v, other))
    return z3.And([v != other for other in others])
def substitute(self, other, expr):
    """Rewrite ``expr`` by replacing this object's symbolic state with ``other``'s.

    The substitution map is built from:
      * for each address in ``self._expanded_global_state``: its storage,
        balance and nonce -> those of ``other.initial_contract_state(addr)``;
      * for each ``(k, v)`` in ``self._cache``: ``v`` -> ``getattr(other, k)()``.
        For ``mem.Memory`` values the underlying ``_mem`` arrays are mapped.
        Values that are applications with arguments are treated as *function*
        substitutions: every application of the same declaration inside
        ``expr`` becomes ``z3.substitute_vars(new_value, *rewritten_args)``.

    The walk uses an explicit work stack (post-order) instead of recursion.

    Raises:
        NotImplementedError: if a quantifier is encountered.
    """
    def update_term(t, args):
        # Rebuild application `t` with new arguments through the C API. Used
        # when the argument count differs from the declaration's arity
        # (presumably n-ary builtins), where calling the decl directly is not
        # possible.
        n = len(args)
        # Need to pass an AstArray type into Z3_update_term, not a python list
        args_ast_arr = (z3.Ast * n)()
        for i in range(n):
            args_ast_arr[i] = args[i].as_ast()
        return _to_expr_ref(
            z3.Z3_update_term(t.ctx_ref(), t.as_ast(), n, args_ast_arr), t.ctx)

    cache = z3.AstMap(ctx=expr.ctx)
    for addr, state in self._expanded_global_state.items():
        other_state = other.initial_contract_state(addr)
        cache[state.storage] = other_state.storage
        cache[state.balance] = other_state.balance
        cache[state.nonce] = other_state.nonce

    fnsubs = []
    for k, v in self._cache.items():
        if isinstance(v, mem.Memory):
            # Uses of Memory objects will produce an expression containing
            # the underlying array object (_mem). If required the index
            # (_idx) will have been substituted with the actual indexing
            # expression at that point, so we do not need to consider it
            # here.
            cache[v._mem] = getattr(other, k)()._mem
        elif z3.is_app(v) and v.num_args() > 0:
            # Match by *declaration*: the walk below compares n.decl() against
            # these entries, and a FuncDecl never compares equal (z3.eq) to an
            # application AST, so storing `v` itself made this branch dead.
            fnsubs.append((v.decl(), getattr(other, k)()))
        else:
            cache[v] = getattr(other, k)()

    todo = [expr]
    while todo:
        n = todo[-1]
        if n in cache:
            todo.pop()
        elif z3.is_var(n):
            cache[n] = n
            todo.pop()
        elif z3.is_app(n):
            new_args = []
            for i in range(n.num_args()):
                arg = n.arg(i)
                if arg not in cache:
                    todo.append(arg)
                else:
                    new_args.append(cache[arg])
            # Only actually do the substitution if all the arguments have
            # already been processed
            if len(new_args) == n.num_args():
                todo.pop()
                fn = n.decl()
                for oldfn, newfn in fnsubs:
                    if z3.eq(fn, oldfn):
                        new_fn = z3.substitute_vars(newfn, *new_args)
                        break
                else:
                    # TODO only if new_args != old_args
                    if len(new_args) != fn.arity():
                        new_fn = update_term(n, new_args)
                    else:
                        new_fn = fn(*new_args)
                cache[n] = new_fn
        else:
            assert z3.is_quantifier(n)
            # Not currently implemented as don't use quantifiers at the
            # moment
            raise NotImplementedError()
    return cache[expr]