def make_asts(self):
    """
    Build the AST of every function in the contract.

    We need to do AST creation at the contract level, not the function level,
    because some simplifications (type/field removal) require insight into all
    the functions, not just a single one.
    """
    # Build each function's AST from its trace first.
    for func in self.functions:
        func.ast = self.make_ast(func.trace)

    # Finder for find_f_list: collect every ("type", ...) node.
    def find_stor_masks(exp):
        if opcode(exp) == "type":
            return [exp]
        else:
            return []

    # De-duplicated set of all ("type", ...) nodes across *all* function ASTs.
    # NOTE(review): mlist and cleanup are not used in this view of the body.
    mlist = set(
        find_f_list([f.ast for f in self.functions], find_stor_masks))

    # Strip field/type wrappers around a storage-array length reference,
    # leaving the bare ("stor", ("length", idx)).
    def cleanup(exp):
        if m := match(exp, ("field", 0, ("stor", ("length", ":idx")))):
            return ("stor", ("length", m.idx))
        if m := match(exp, ("type", 256, ("field", 0, ("stor", ("length", ":idx"))))):
            return ("stor", ("length", m.idx))
def simplify_string_getter_from_storage(self):
    """
    Heuristic: find string getters that return a storage-backed array and
    replace them with a simplified version.

    Only read-only functions with at least one return are considered, and
    every return must have the shape
    return(data(arr(storage(256, 0, length(loc)), ...))).

    Test cases:
        unicorn 0xF7dF66B1D0203d362D7a3afBFd6728695Ae22619 name
        0xf8e386EDa857484f5a12e4B5DAa9984E06E73705 version

    To see how it works, turn this function off and compare how the test
    cases decompile.
    """
    if not self.read_only or len(self.returns) == 0:
        return

    for r in self.returns:
        m = match(
            r,
            (
                "return",
                ("data", ("arr", ("storage", 256, 0, ("length", ":loc")), ...)),
            ),
        )
        if not m:
            # One non-matching return disqualifies the whole function.
            return
        loc = m.loc
def loc_to_name(exp):
    """Map ("loc", n) to a ("name", label, n) node.

    The label is "stor" followed by the decimal number for n < 1000, and
    otherwise by the first four hex digits of n in upper case.
    """
    hit = match(exp, ("loc", ":int:num"))
    if hit:
        num = hit.num
        suffix = str(num) if num < 1000 else hex(num)[2:6].upper()
        return ("name", "stor" + suffix, num)
def rewrite_memcpy(lines):  # 2
    """Rewrite a calldata-to-memory copy into a direct setmem of call.data.

    Expects exactly two trace lines. When the first one matches

        setmem(range(s, mask_shl(251, 5, 0, 31 + cd[4 + param])),
               data(call.data(36 + param, cd[4 + param]), mem...))

    the result is

        setmem(range(s, cd[4 + param]), call.data(36 + param, cd[4 + param]))

    Otherwise the function falls through and returns None.
    """
    assert len(lines) == 2
    l1 = lines[0]
    l2 = lines[1]

    if m := match(
        l1,
        (
            "setmem",
            (
                "range",
                ":s",
                ("mask_shl", 251, 5, 0, ("add", 31, ("cd", ("add", 4, ":param")))),
            ),
            (
                "data",
                ("call.data", ("add", 36, ":param"), ("cd", ("add", 4, ":param"))),
                ("mem", ...),
            ),
        ),
    ):
        # The captured start offset lives on the match object; the bare name
        # `s` used previously was never bound and raised NameError.
        return (
            "setmem",
            ("range", m.s, ("cd", ("add", 4, m.param))),
            ("call.data", ("add", 36, m.param), ("cd", ("add", 4, m.param))),
        )
def pretty_stor(exp, add_color=True):
    """Pretty-print a storage expression.

    ("stor", ("length", idx)) is rendered as the pretty form of idx followed
    by a green ".length" suffix.
    """
    col = partial(colorize, color=COLOR_GREEN, add_color=add_color)
    stor = partial(pretty_stor, add_color=add_color)
    pret = partial(prettify, parentheses=False, add_color=add_color)

    m = match(exp, ("stor", ("length", ":idx")))
    if m:
        return stor(m.idx) + col(".length")
def postprocess_trace(line):
    """Collapse length-parity branches around storage-backed Array returns.

    Finds code like::

        if (some_len % 32) == 0:
            return Array(some_len, some_stuff)
        else:
            mem[...] = leftover
            return Array(some_len, some_stuff, leftover)

    and replaces it with just::

        return Array(some_len, some_stuff)

    In theory this is incorrect, because the program might do something
    entirely different in one branch than in the other. But it cleans up
    tremendous amounts of output, and no counterexample has been found yet.

    Returns the ``if_true`` branch when both branches read the same number of
    storage arrays; otherwise falls through.
    """
    if m := match(line, ("if", ("iszero", ("storage", 5, 0, ":l")), ":if_true", ":if_false")):
        l, if_true, if_false = m.l, m.if_true, m.if_false

        def find_arr_l(exp):
            # Collect arr(...) reads of the same storage slot `l` that the
            # condition inspects.
            if match(exp, ("arr", ("storage", 256, 0, l), ...)):
                return [exp]
            return []

        true_arr = find_f_list(if_true, find_arr_l)
        # Scan the *else* branch. Scanning if_true twice (as before) made the
        # length comparison below succeed unconditionally.
        false_arr = find_f_list(if_false, find_arr_l)

        if len(true_arr) > 0 and len(true_arr) == len(false_arr):
            return if_true
def rewrite_string_stores(lines):
    """Collapse a three-line string copy into storage into a single store.

    Ugly and not super-precise. It should be split into two parts:
    converting array->storage writes in loop_to_setmem_from_storage, and then
    relying on those storage writes here for cleanup.

    Expects three lines: a store whose value is built from ``src``, followed
    by two ``while`` loops. The first loop's body must be two lines, starting
    with a 32-byte store from memory.
    """
    assert len(lines) == 3
    l1, l2, l3 = lines[0], lines[1], lines[2]

    m1 = match(
        l1,
        ("store", 256, 0, ":idx", ("add", 1, ("mask_shl", 255, 0, 1, ":src"))),
    )
    if m1:
        m2 = match(l2, ("while", ("gt", Any, Any), ":path2", Any, ":setvars"))
        if (
            m2
            and match(l3, ("while", ("gt", ...), ":path3", ...))
            and len(m2.path2) == 2
        ):
            x = m2.path2[0]
            store_from_mem = (
                "store",
                256,
                0,
                ("add", ("var", Any), Any),
                ("mem", ("range", ("var", ":v"), 32)),
            )
            if x and match(x, store_from_mem):
                target = ("array", "", ("sha3", m1.idx))
                value = ("arr", m1.src, ("mem", ("range", m2.setvars[1][2], m1.src)))
                return [("store", 256, 0, target, value)]
def other_2(exp):
    """Rewrite ``if a == b: return (a == b)`` as ``if a == b: return bool(1)``.

    Inside the branch the comparison is already known to hold, so the
    returned value is replaced by the constant ("bool", 1).
    """
    m = match(exp, ("if", ("eq", ":a", ":b"), ":if_true"))
    if m:
        cond = ("eq", m.a, m.b)
        if m.if_true == [("return", cond)]:
            return ("if", cond, [("return", ("bool", 1))])
def split_or(value):
    """Split an ``or`` expression into masked rows.

    A value that is neither ``or`` nor ``mask_shl`` is returned as the single
    full-width row ``[(256, 0, value)]``. A lone ``mask_shl`` is treated as a
    one-term ``or``. Each term is then normalised: ``bool``, ``caller`` and
    ``block.timestamp`` get an explicit mask width, and ``mul(1, x)`` is
    unwrapped to ``x``.
    """
    orig_value = value
    if opcode(value) not in ('or', 'mask_shl'):
        return [(256, 0, value)]

    if opcode(value) == 'mask_shl':
        value = ('or', value)

    opcode_, *terms = value
    assert opcode_ == 'or'

    ret_rows = []
    for row in terms:
        # Explicit widths below: size-1 masks do weird things
        # (seen in loops.activateSafeMode).
        if m := match(row, ('bool', ':arg')):
            row = ('mask_shl', 8, 0, 0, ('bool', m.arg))
        if row == 'caller':
            row = ('mask_shl', 160, 0, 0, 'caller')
        if row == 'block.timestamp':
            # Wrap the timestamp itself. This previously wrapped 'caller'
            # (copy-paste slip), silently turning block.timestamp into caller.
            row = ('mask_shl', 64, 0, 0, 'block.timestamp')
        if m := match(row, ('mul', 1, ':val')):
            row = m.val
def apply_mask_to_storage(exp, size, offset, shl):
    """Apply a mask (size, offset, shl) to a ("storage", ...) read.

    ``exp`` must be ("storage", stor_size, stor_offset, stor_idx); this is
    asserted. The mask offset is folded into the storage offset, size and
    ``shl``. Returns 0 when the resulting width is reported as <= 0, and the
    narrowed storage read when the folded ``shl`` is 0. Otherwise it returns
    a storage read with ``-shl`` as offset, when the narrowed read matches
    ("storage", size, 0, idx).
    """
    m = match(exp, ("storage", ":stor_size", ":stor_offset", ":stor_idx"))
    assert m
    stor_size, stor_offset, stor_idx = m.stor_size, m.stor_offset, m.stor_idx

    # shr = minus_op(shl)

    # Fold the mask offset into the read: start later and get narrower.
    stor_offset = add_op(stor_offset, offset)
    stor_size = sub_op(stor_size, offset)

    shl = add_op(shl, offset)
    offset = 0

    # Narrow to the mask width when safe_lt_op reports it as smaller.
    if safe_lt_op(size, stor_size):
        stor_size = size

    # Nothing left to read.
    if safe_le_op(stor_size, 0) is True:
        return 0

    res = ("storage", stor_size, stor_offset, stor_idx)

    shr = 0
    if shl == 0:
        return res
    else:
        # NOTE(review): `offset` was reset to 0 above, so `offset == 0` always
        # holds here. m.stor_idx equals stor_idx by construction. When the
        # match fails this falls through (returns None); confirm intended.
        if (m := match(res, ("storage", size, 0, ":stor_idx"))) and offset == 0:
            stor_idx = m.stor_idx
            shr = minus_op(shl)
            return ("storage", size, shr, stor_idx)
def add_path(line):
    """Prefix a ``goto`` line with the accumulated ``path``.

    ``path`` is a free variable taken from the enclosing scope (the
    statements collected so far). For ("goto", _, svs), every ("var", idx) in
    ``path`` is replaced by its value from ``svs`` before the goto is
    appended. Any other line is returned alone, wrapped in a list.
    """
    if m := match(line, ("goto", Any, ":svs")):
        path2 = path
        for _, v_idx, v_val in m.svs:
            path2 = replace(path2, ("var", v_idx), v_val)
        return path2 + [line]
    # Non-goto lines previously fell through to an implicit None. Return a
    # one-element list so the result is always a list, as for goto lines.
    return [line]
def make_trace(self):
    """Flatten this node, and what follows it, into one trace list.

    A node whose trace is None yields a single ("undefined", ...) marker.

    When ``self.vm.just_fdests`` is set and the trace is not just
    ``revert(0)``, the result starts with ("jd", ...). That marker is ``"?"``
    when the first line is not a jump. Label nodes then add a
    ("label", self, setvars) line.

    A trailing jump is replaced by ``last[1].make_trace()``. A trailing
    ``if`` has both branch nodes expanded recursively.
    """
    if self.trace is None:
        return [("undefined", "decompilation didn't finish")]

    if self.is_label():
        begin_vars = [
            ("setvar", var_idx, var_val)
            for _, var_idx, var_val, _ in self.label.begin_vars
        ]
    else:
        begin_vars = []

    begin = []
    if self.vm.just_fdests and self.trace != [("revert", 0)]:
        head = self.trace[0]
        if match(head, ("jump", ":target_node", ...)):
            begin = [("jd", str(self.jd[0]))]
        else:
            begin = ["?"]

    if self.is_label():
        begin.append(("label", self, tuple(begin_vars)))

    last = self.trace[-1]

    if opcode(last) == "jump":
        prefix = begin + self.trace[:-1]
        return prefix + last[1].make_trace()

    m = match(last, ("if", ":cond", ":if_true", ":if_false"))
    if m:
        true_trace = m.if_true.make_trace()
        false_trace = m.if_false.make_trace()
        return begin + self.trace[:-1] + [("if", m.cond, true_trace, false_trace)]
def to_while(trace, jd, path=None):
    """Walk ``trace`` looking for the loop that jumps back to ``jd``.

    An ``if`` whose one branch reverts is folded into a ``require`` on
    ``path`` and the walk continues down the other branch. For any other
    ``if``, exactly one branch must reference ``jd``; otherwise an
    AssertionError is raised. If ``trace`` is exhausted first, a RuntimeError
    is raised.

    :param trace: list of trace lines following the loop label.
    :param jd: jump-destination id of the loop head.
    :param path: statements accumulated so far; a new list when omitted.
        NOTE: a non-empty list passed in is appended to in place.
    """
    path = path or []
    while True:
        if trace == []:
            # A bare `raise` with no exception being handled raises
            # RuntimeError("No active exception to reraise"). Raise the same
            # type explicitly, with a message that says what went wrong.
            raise RuntimeError(f"to_while: trace exhausted while looking for {jd!r}")
        line = trace[0]
        trace = trace[1:]
        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            cond, if_true, if_false = m.cond, m.if_true, m.if_false
            if is_revert(if_true):
                path.append(("require", is_zero(cond)))
                trace = if_false
                continue
            if is_revert(if_false):
                path.append(("require", cond))
                trace = if_true
                continue
            jds_true = find_f_list(if_true, get_jds)
            jds_false = find_f_list(if_false, get_jds)
            # Exactly one branch must lead back to `jd`. Raised explicitly
            # rather than via `assert` so the check survives `python -O`;
            # callers such as make() catch it to reject non-loops.
            if (jd in jds_true) == (jd in jds_false):
                raise AssertionError((jds_true, jds_false))

            def add_path(line):
                # Prefix a goto with `path`, substituting the goto's set-vars
                # into it; any other line is returned on its own.
                if m := match(line, ("goto", Any, ":svs")):
                    path2 = path
                    for _, v_idx, v_val in m.svs:
                        path2 = replace(path2, ("var", v_idx), v_val)
                    return path2 + [line]
                else:
                    return [line]
def make(trace):
    """Rebuild a trace, turning label-headed loops into ``while`` nodes.

    - ``if`` lines are rebuilt with both branches processed recursively.
    - At a ``label`` line, to_while splits the rest of the trace into
      (before, inside, remaining, cond). The loop becomes
      ("while", cond, inside, repr(jd), vars), preceded by ``before`` (with
      the label's vars substituted) and followed by ``remaining``.
    - If to_while raises, the label line is dropped and scanning continues.
    - Every other line is copied through unchanged.

    Always returns a list. The recursive calls below extend ``res`` with the
    result, so they rely on that.
    """
    res = []
    for idx, line in enumerate(trace):
        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            res.append(("if", m.cond, make(m.if_true), make(m.if_false)))
        elif m := match(line, ("label", ":jd", ":vars", ...)):
            jd, loop_vars = m.jd, m.vars
            try:
                before, inside, remaining, cond = to_while(trace[idx + 1:], jd)
            except Exception:
                # Not a recognisable loop: skip the label, keep scanning.
                continue

            inside = make(inside)
            remaining = make(remaining)

            for _, v_idx, v_val in loop_vars:
                before = replace(before, ("var", v_idx), v_val)
            before = make(before)

            res.extend(before)
            res.append(("while", cond, inside, repr(jd), loop_vars))
            res.extend(remaining)
            # to_while was handed the whole tail trace[idx + 1:], and its
            # pieces stand in for it, so stop here.
            return res
        else:
            # Ordinary statements were previously dropped.
            res.append(line)
    return res
def get_param_name(cd, add_color=False, func=None):
    """Return a readable name for the calldata reference ``cd``.

    ``cd`` must be ("cd", loc). It is returned unchanged when no ABI, no
    current function, or no "params" entry is available. For a symbolic loc
    of the form 4 + param(p) or 4 + cd[p], the result is the (recursively
    named) parameter followed by a green ".length".
    """
    global _func
    global _abi
    loc = match(cd, ("cd", ":loc")).loc

    # Without ABI information for the current function there is nothing to name.
    if _abi is None or _func is None or "params" not in _func:
        return cd

    if type(loc) is not int:
        # NOTE(review): cd is cleaned here, but `loc` is not recomputed from
        # the cleaned cd, so the matches below see the raw loc; confirm.
        cd = cleanup_mul_1(cd)

        hit = match(loc, ("add", 4, ("param", ":point_loc")))
        if hit:
            return colorize(hit.point_loc + ".length", COLOR_GREEN, add_color)

        hit = match(loc, ("add", 4, ("cd", ":point_loc")))
        if hit:
            inner = get_param_name(("cd", hit.point_loc), func=func)
            return colorize(str(inner) + ".length", COLOR_GREEN, add_color)
def simplify(exp):
    """Recursively simplify an expression (``max`` and ``mask_shl`` cases).

    - ("max", t1, t2, ...): the terms are simplified and folded with max_op.
      If max_op fails on any term, a ``max`` of the simplified terms is
      returned instead.
    - ("mask_shl", size, offset, shl, val): all four parts are simplified.
      When all are concrete the mask is applied. A full-width no-op mask
      (256, 0, 0) is dropped.
    """
    if opcode(exp) == "max":
        terms = exp[1:]
        els = [simplify(e) for e in terms]
        res = -(2 ** 256)
        for e in els:
            try:
                res = max_op(res, e)
            except Exception:
                # max_op could not compare (presumably symbolic terms); keep
                # a symbolic max. `Exception` rather than a bare `except:` so
                # KeyboardInterrupt/SystemExit are not swallowed.
                return ("max",) + tuple(els)
        return res

    if m := match(exp, ("mask_shl", ":size", ":offset", ":shl", ":val")):
        size, offset, shl, val = (
            simplify(m.size),
            simplify(m.offset),
            simplify(m.shl),
            simplify(m.val),
        )
        if all_concrete(size, offset, shl, val):
            return apply_mask(val, size, offset, shl)
        if (size, offset, shl) == (256, 0, 0):
            return val
def lt_op(left, right):  # left < right
    """Compare ``left < right``.

    Two plain ints are compared directly. A symbolic left side of the form
    num + max(t1, t2, ...) is rewritten as max(t1 + num, t2 + num, ...).
    """
    if type(left) is int and type(right) is int:
        return left < right

    m = match(left, ("add", ":int:num", ":max"))
    if m and opcode(m.max) == "max":
        terms = m.max[1:]
        # Distribute the constant over the max terms.
        left = ("max", *(add_op(t, m.num) for t in terms))
def add_to_arr(exp):
    """Turn ``loc + x`` (in either operand order) into ("array", x, loc)."""
    m = match(exp, ("add", ":left", ":right"))
    if m:
        # Normalise so that a "loc" operand, if any, ends up on the right.
        left, right = (
            (m.right, m.left) if opcode(m.left) == "loc" else (m.left, m.right)
        )
        if opcode(right) == "loc":
            return ("array", left, right)
def other_1(exp):
    """Unwrap a quoted string literal wrapped in a mask_shl.

    Matches mask_shl(size, 256 - size, size - 256, "'...'") where the
    literal, including its surrounding quotes, is exactly size + 16 bits
    long, and returns the literal.
    """
    m = match(exp, ("mask_shl", ":int:size", ":n_size", ":size_n", ":str:val"))
    if m:
        text = m.val
        shifts_match = 256 - m.size == m.n_size and m.size - 256 == m.size_n
        # +16 because of the '' around the string
        if (
            shifts_match
            and m.size + 16 == len(text) * 8
            and len(text) > 0
            and text[0] == text[-1] == "'"
        ):
            return text
def _mask_op(exp, size=256, offset=0, shl=0, shr=0):
    """Apply a mask to ``exp``.

    A zero-width mask always yields 0. A div(x, 1) operand is unwrapped to x
    before masking.
    """
    if size == 0:
        return 0

    # Unwrap div(x, 1) -> x. This arguably belongs in a general simplifier.
    m = match(exp, ("div", ":num", 1))
    if m:
        exp = m.num
def simplify_sha3(e):
    """Simplify a sha3 expression.

    First applies rainbow_sha3. Then sha3(data(a, b, ...)) is flattened to
    sha3(a, b, ...), and sha3(<int>) becomes ("loc", <int>).
    """
    e = rainbow_sha3(e)

    if match(e, ("sha3", ("data", ...))):
        # Drop the "data" wrapper and splice its operands into the sha3.
        e = ("sha3", *e[1][1:])

    hit = match(e, ("sha3", ":int:loc"))
    if hit:
        return ("loc", hit.loc)
def find_default(exp):
    """Return the jump destination heading the else-branch of a cd[0] check.

    Matches ``if cond: ... else: ...`` where ``cond`` mentions ("cd", 0)
    (presumably the function-selector dispatch). If the else-branch is
    non-empty, makes no function calls, and starts with ("jd", x), returns
    int(x). Otherwise falls through (None).
    """
    if (m := match(exp, ("if", ":cond", ":if_true", ":if_false"))) and str(
        ("cd", 0)
    ) in str(m.cond):
        # Guard against an empty else-branch, which previously raised
        # IndexError on m.if_false[0].
        if m.if_false and find_f_list(m.if_false, func_calls) == []:
            fi = m.if_false[0]
            if m2 := match(fi, ("jd", ":jd")):
                return int(m2.jd)
def apply_vars(self, var_list):
    """Rename variables in this node's trace and, recursively, its successors.

    The method reads ``self.trace``/``self.next`` and recurses through
    ``n.apply_vars(var_list)``, so it needs ``self``. The missing parameter
    made every call fail.

    :param var_list: pairs (orig_name, new_name), both ("var", ...)
        expressions (new_name presumably with an int index, per the assert).
        It must be re-iterable (e.g. a list), because it is passed on to
        every successor.
    """
    for orig_name, new_name in var_list:
        assert match(orig_name, ("var", ":name"))
        assert match(new_name, ("var", int))
        self.trace = replace(self.trace, orig_name, new_name)

    # NOTE(review): no visited set, so this assumes the self.next graph is
    # acyclic; confirm.
    for n in self.next:
        n.apply_vars(var_list)
def get_type(stordefs):
    """Gather the sizes and offsets of map/array storage definitions.

    Only ("stor", size, off, (op, idx, ...)) entries with op "map" or "array"
    are considered. Offsets are kept only when safe_le_op confirms they are
    >= 0.
    """
    sizes, offsets = set(), set()
    for s in stordefs:
        m = match(s, ("stor", ":size", ":off", (":op", ":idx", ...)))
        if m and m.op in ("map", "array"):
            sizes.add(m.size)
            if safe_le_op(0, m.off) is True:
                offsets.add(m.off)
def to_bytes(exp):
    """Divide a bit quantity by 8, returning (bytes, leftover_bits).

    An int gives (ceil(exp / 8), exp % 8). A symbolic value shifted left by 3
    or more bits has 3 of those shift bits removed, with 0 leftover.
    """
    if type(exp) is int:
        return (exp + 7) // 8, exp % 8

    if type(exp) is tuple and exp[:4] == ("mask_shl", 253, 0, 3):
        return exp[4], 0

    m = match(exp, ("mask_shl", ":int:size", ":int:offset", ":int:shl", ":val"))
    if m and m.shl >= 3:
        return ("mask_shl", m.size, m.offset, m.shl - 3, m.val), 0
def mask_to_mul(exp):
    """Rewrite a small power-of-two shift mask as mul/div.

    For mask_shl(size, offset, shl, val) with concrete ints:

    - shl in 1..8, offset == 0, size == 256 - shl  ->  ("mul", 2**shl, val)
    - shl in -8..-1, offset == -shl, size == 256 - offset
      ->  ("div", val, 2**offset)

    Returns None for anything else.
    """
    if m := match(
        exp, ("mask_shl", ":int:size", ":int:offset", ":int:shl", ":val")):
        size, offset, shl, val = m.size, m.offset, m.shl, m.val
        if shl > 0 and offset == 0 and size == 256 - shl:
            if shl <= 8:
                return ("mul", 2**shl, val)
        if shl < 0 and offset == -shl and size == 256 - offset:
            if shl >= -8:
                # Divisor is 2**offset (== 2**-shl). The old `2**shl` was a
                # float (e.g. 0.125) and sat in the dividend slot; div
                # operands are (dividend, divisor), as in _mask_op's
                # div(x, 1) -> x unwrap.
                return ("div", val, 2**offset)
def f(exp):
    """Return the number of a ("loc", num) node, refusing to enter storage.

    Non-tuples and "storage"/"stor" nodes yield None. When looking for a loc,
    a storage reference nested within storage must not be followed.
    """
    if type(exp) is not tuple or opcode(exp) in ("storage", "stor"):
        return None
    hit = match(exp, ("loc", ":num"))
    if hit:
        return hit.num
def __try_add(self, other):
    """Attempt to combine ``self`` with ``other``.

    First normalises ``self`` of the form
    mul(num, mask_shl(size, off, shl, val)) with shl > 0 into
    mul(num + 2**shl, mask_shl(size + shl, off, 0, val)).
    """
    m = match(
        self,
        ("mul", ":num", ("mask_shl", ":int:size", ":int:off", ":int:shl", ":val")),
    )
    if m and m.shl > 0:
        # NOTE(review): `num + 2**shl` (rather than `num * 2**shl`) and the
        # widened size look arithmetically suspicious; confirm against tests.
        factor = m.num + 2**m.shl
        widened = ("mask_shl", m.size + m.shl, m.off, 0, m.val)
        self = ("mul", factor, widened)
def slice_exp(exp, left, right):
    """Return the bytes of ``exp`` from offset ``left`` up to offset ``right``.

    For a memory range mem[rleft len rlen], the result is
    mem[rleft + left len right - left]. It is None when safe_le_op cannot
    confirm that the slice fits within rlen.
    """
    size = sub_op(right, left)
    logger.debug(f"slicing {exp}, offset {left} bytes, until {right} bytes")

    # e.g. mem[32 len 10], 2, 4 == mem[34,2]
    hit = match(exp, ("mem", ("range", ":rleft", ":rlen")))
    if hit:
        rleft, rlen = hit.rleft, hit.rlen
        fits = safe_le_op(add_op(left, size), rlen)
        return ("mem", ("range", add_op(rleft, left), size)) if fits else None
def pretty_line(r, add_color=True):
    """Yield the pretty-printed text for a single trace line ``r``.

    Plain strings and ("comment", text) lines are rendered as gray "# ..."
    comments.
    """
    col = partial(colorize, add_color=add_color)
    pret = partial(prettify, parentheses=False, add_color=add_color)

    if type(r) is str:
        yield "".join((COLOR_GRAY, "# ", r, ENDC))
    elif m := match(r, ("comment", ":text")):
        text = prettify(m.text, add_color=False)
        yield "".join((COLOR_GRAY, "# ", text, ENDC))