def get_param_name(cd, add_color=False, func=None):
    """Render a calldata reference as a parameter-style name.

    ``cd`` must match ``("cd", <loc>)``.  The expression is returned
    untouched while the module-level ``_abi`` or ``_func`` is unset, or
    ``_func`` has no ``"params"`` entry.  Locations of the form
    ``("add", 4, ("param", x))`` / ``("add", 4, ("cd", x))`` are rendered
    as ``<name>.length`` through ``colorize``.
    """
    loc = match(cd, ("cd", ":loc")).loc

    # Nothing to resolve the location against without ABI information.
    if _abi is None or _func is None or "params" not in _func:
        return cd

    if type(loc) is not int:
        cd = cleanup_mul_1(cd)

    by_param = match(loc, ("add", 4, ("param", ":point_loc")))
    if by_param:
        return colorize(by_param.point_loc + ".length", COLOR_GREEN, add_color)

    by_cd = match(loc, ("add", 4, ("cd", ":point_loc")))
    if by_cd:
        # The inner name is resolved uncoloured; colour is applied once here.
        inner = get_param_name(("cd", by_cd.point_loc), func=func)
        return colorize(str(inner) + ".length", COLOR_GREEN, add_color)
def make_asts(self):
    """
    Build the AST of every function at contract level.

    This is done for the whole contract rather than per function, because
    some simplifications (type/field removal) need insight into all the
    functions, not just a single one.
    """
    for fn in self.functions:
        fn.ast = self.make_ast(fn.trace)

    def find_stor_masks(exp):
        # Only ("type", ...) expressions are collected.
        return [exp] if opcode(exp) == "type" else []

    all_asts = [fn.ast for fn in self.functions]
    stor_masks = frozenset(find_f_list(all_asts, find_stor_masks))

    # Index the collected masks by get_loc() and by get_name().
    stor_loc_to_masks = collections.defaultdict(set)
    stor_name_to_masks = collections.defaultdict(set)
    for mask in stor_masks:
        stor_loc_to_masks[get_loc(mask)].add(mask)
        stor_name_to_masks[get_name(mask)].add(mask)

    def cleanup(exp):
        # A field-0 view of a storage length (optionally wrapped in a
        # 256-bit type) collapses to the bare storage length.
        field_of_length = match(
            exp, ("field", 0, ("stor", ("length", ":idx"))))
        if field_of_length:
            return ("stor", ("length", field_of_length.idx))

        typed_field = match(
            exp,
            ("type", 256, ("field", 0, ("stor", ("length", ":idx")))))
        if typed_field:
            return ("stor", ("length", typed_field.idx))
def add_path(line):
    """Return the enclosing ``path`` followed by ``line`` for goto lines.

    For a ``("goto", _, svs)`` line, every ``("var", idx)`` in the
    enclosing-scope ``path`` is first replaced by the value the goto
    assigns to that variable.  Other lines yield ``None``.
    """
    goto = match(line, ("goto", Any, ":svs"))
    if not goto:
        return None

    substituted = path
    for _, var_idx, var_val in goto.svs:
        substituted = replace(substituted, ("var", var_idx), var_val)
    return substituted + [line]
def simplify(exp):
    """Simplify ``max`` and ``mask_shl`` expressions.

    * ``("max", t1, t2, ...)``: terms are simplified recursively and folded
      with ``max_op``; if folding raises, a ``max`` node over the
      simplified terms is returned instead.
    * ``("mask_shl", size, offset, shl, val)``: operands are simplified;
      the mask is applied when ``all_concrete`` accepts them, and a full
      256-bit mask with no offset or shift returns the value itself.
    """
    if opcode(exp) == "max":
        els = [simplify(term) for term in exp[1:]]
        res = -(2**256)  # lower bound used as the fold seed
        try:
            for el in els:
                res = max_op(res, el)
        except Exception:
            return ("max", *els)
        return res

    m = match(exp, ("mask_shl", ":size", ":offset", ":shl", ":val"))
    if m:
        size = simplify(m.size)
        offset = simplify(m.offset)
        shl = simplify(m.shl)
        val = simplify(m.val)

        if all_concrete(size, offset, shl, val):
            return apply_mask(val, size, offset, shl)

        if (size, offset, shl) == (256, 0, 0):
            return val
def loc_to_name(exp):
    """Turn ``("loc", <int>)`` into ``("name", "stor<suffix>", <int>)``.

    Locations below 1000 use the decimal number as the suffix; larger ones
    use the first four hex digits, upper-cased.  Anything else yields
    ``None``.
    """
    m = match(exp, ("loc", ":int:num"))
    if not m:
        return None

    num = m.num
    suffix = str(num) if num < 1000 else hex(num)[2:6].upper()
    return ("name", "stor" + suffix, num)
def pretty_stor(exp, add_color=True):
    """Pretty-print a storage expression.

    ``("stor", ("length", idx))`` is rendered as the pretty-printed ``idx``
    followed by a green ``.length``.
    """
    # Pre-bound helpers carrying the current colour setting.
    pret = partial(prettify, parentheses=False, add_color=add_color)
    stor = partial(pretty_stor, add_color=add_color)
    col = partial(colorize, color=COLOR_GREEN, add_color=add_color)

    length_of = match(exp, ("stor", ("length", ":idx")))
    if length_of:
        return stor(length_of.idx) + col(".length")
def make_trace(self):
    """Flatten this node and the nodes it leads to into one trace list.

    Returns a single ``("undefined", ...)`` entry when the node has no
    trace.  Otherwise the result starts with optional marker entries (a
    ``jd``/``"?"`` marker when ``self.vm.just_fdests`` is set, and a
    ``label`` entry for label nodes), followed by this node's trace with
    a trailing ``jump`` or ``if`` expanded recursively.
    """
    if self.trace is None:
        return [("undefined", "decompilation didn't finish")]

    begin_vars = []
    if self.is_label():
        begin_vars = [
            ("setvar", var_idx, var_val)
            for _, var_idx, var_val, _ in self.label.begin_vars
        ]

    begin = []
    if self.vm.just_fdests and self.trace != [("revert", 0)]:
        first = self.trace[0]
        if match(first, ("jump", ":target_node", ...)):
            begin = [("jd", str(self.jd[0]))]
        else:
            begin = ["?"]

    if self.is_label():
        begin.append(("label", self, tuple(begin_vars)))

    last = self.trace[-1]

    if opcode(last) == "jump":
        # Continue with the trace of the jump target.
        return begin + self.trace[:-1] + last[1].make_trace()

    branch = match(last, ("if", ":cond", ":if_true", ":if_false"))
    if branch:
        if_true = branch.if_true.make_trace()
        if_false = branch.if_false.make_trace()
        expanded = ("if", branch.cond, if_true, if_false)
        return begin + self.trace[:-1] + [expanded]
def apply_mask_to_storage(exp, size, offset, shl):
    """Apply a (size, offset, shl) mask to a ``("storage", ...)`` expression.

    ``exp`` must match ``("storage", stor_size, stor_offset, stor_idx)``;
    this is enforced with an ``assert``.

    Returns ``0`` when the remaining storage size is provably <= 0, the
    narrowed ``("storage", ...)`` tuple when no shift remains, or a
    storage tuple whose offset field holds the negated shift when the
    narrowed storage starts at offset 0 with exactly ``size`` bits.  Falls
    through (``None``) otherwise in the visible code.
    """
    m = match(exp, ("storage", ":stor_size", ":stor_offset", ":stor_idx"))
    assert m
    stor_size, stor_offset, stor_idx = m.stor_size, m.stor_offset, m.stor_idx

    # shr = minus_op(shl)
    # Fold the mask offset into the storage window: the window starts
    # `offset` further in and is `offset` shorter; the offset is then
    # carried by `shl` and reset to 0.
    stor_offset = add_op(stor_offset, offset)
    stor_size = sub_op(stor_size, offset)
    shl = add_op(shl, offset)
    offset = 0

    # The mask may be narrower than what is left of the storage window.
    if safe_lt_op(size, stor_size):
        stor_size = size

    # `is True`: only a definite answer counts (safe_le_op presumably
    # returns a non-True value when it cannot decide -- TODO confirm).
    if safe_le_op(stor_size, 0) is True:
        return 0

    res = ("storage", stor_size, stor_offset, stor_idx)

    shr = 0
    if shl == 0:
        return res
    else:
        # `offset` was reset to 0 above, so `offset == 0` always holds here.
        if (m := match(res, ("storage", size, 0, ":stor_idx"))) and offset == 0:
            stor_idx = m.stor_idx
            shr = minus_op(shl)
            return ("storage", size, shr, stor_idx)
def simplify_string_getter_from_storage(self):
    """
    A heuristic for finding string getters and replacing them with a
    simplified version.

    Test cases:
        unicorn
        0xF7dF66B1D0203d362D7a3afBFd6728695Ae22619 name
        0xf8e386EDa857484f5a12e4B5DAa9984E06E73705 version

    To see how it works, turn this function off and look at how the test
    cases decompile.
    """
    # Only read-only functions with at least one return are candidates.
    if not self.read_only or len(self.returns) == 0:
        return

    string_return = (
        "return",
        ("data", ("arr", ("storage", 256, 0, ("length", ":loc")), ...)),
    )
    for r in self.returns:
        m = match(r, string_return)
        if not m:
            # One non-conforming return disqualifies the function.
            return
        loc = m.loc
def other_2(exp):
    """Rewrite ``if a == b: return a == b`` to ``if a == b: return bool(1)``.

    Applies only when the true-branch is exactly the single return of the
    same comparison; otherwise ``None`` is returned.
    """
    m = match(exp, ("if", ("eq", ":a", ":b"), ":if_true"))
    if not m:
        return None

    comparison = ("eq", m.a, m.b)
    if m.if_true == [("return", comparison)]:
        # Inside the branch the comparison is known to hold.
        return ("if", comparison, [("return", ("bool", 1))])
def make(trace):
    """Rebuild ``trace``, turning labelled sections into ``while`` nodes.

    ``if`` lines are rebuilt with both branches processed recursively.  On
    a ``label`` line the rest of the trace is handed to ``to_while``; on
    success the result is finished with ``before`` lines, one
    ``("while", cond, inside, repr(jd), vars)`` node and the processed
    ``remaining`` lines.

    NOTE(review): in the visible code, lines that are neither ``if`` nor
    ``label`` are not copied to the result, and there is no return after
    the loop -- this excerpt may be incomplete; confirm against the file.
    """
    res = []

    for idx, line in enumerate(trace):
        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            res.append(("if", m.cond, make(m.if_true), make(m.if_false)))

        elif m := match(line, ("label", ":jd", ":vars", ...)):
            jd, vars = m.jd, m.vars
            try:
                # to_while gets everything after the label and is unpacked
                # as (before, inside, remaining, cond).
                before, inside, remaining, cond = to_while(trace[idx + 1:], jd)
            except Exception:
                # Any failure in to_while skips this label line.
                continue

            inside = inside  # + [str(inside)]

            inside = make(inside)
            remaining = make(remaining)

            # Substitute the label's variable values into the lines that
            # precede the loop.
            for _, v_idx, v_val in vars:
                before = replace(before, ("var", v_idx), v_val)
            before = make(before)

            res.extend(before)
            res.append(("while", cond, inside, repr(jd), vars))
            res.extend(remaining)

            # `remaining` already covers the rest of the trace, so stop.
            return res
def to_while(trace, jd, path=None):
    """Walk ``trace`` looking for the branch that loops back to ``jd``.

    Conditions whose other branch reverts are recorded in ``path`` as
    ``("require", ...)`` entries and the walk continues into the surviving
    branch.  For any other ``if``, exactly one branch must contain ``jd``
    (checked with an ``assert``).

    Raises:
        RuntimeError: the trace was exhausted.  The original used a bare
            ``raise`` here, which outside an ``except`` block produces the
            same exception type with an unhelpful message.
    """
    path = path or []

    while True:
        if not trace:
            raise RuntimeError("to_while: trace exhausted before the loop was found")

        line = trace[0]
        trace = trace[1:]

        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            cond, if_true, if_false = m.cond, m.if_true, m.if_false

            # A reverting branch becomes a `require` on the other branch.
            if is_revert(if_true):
                path.append(("require", is_zero(cond)))
                trace = if_false
                continue

            if is_revert(if_false):
                path.append(("require", cond))
                trace = if_true
                continue

            jds_true = find_f_list(if_true, get_jds)
            jds_false = find_f_list(if_false, get_jds)

            # Exactly one of the two branches may lead back to `jd`.
            assert (jd in jds_true) != (jd in jds_false), (jds_true, jds_false)

            def add_path(line):
                # For goto lines, prefix the line with `path` after
                # substituting the variables the goto sets.
                if m := match(line, ("goto", Any, ":svs")):
                    path2 = path
                    for _, v_idx, v_val in m.svs:
                        path2 = replace(path2, ("var", v_idx), v_val)
                    return path2 + [line]
                else:
                    return [line]
def simplify_sha3(e):
    """Simplify a ``sha3`` expression.

    After ``rainbow_sha3`` is applied, a ``("sha3", ("data", ...))`` node is
    flattened so the data terms become direct arguments; a ``sha3`` of a
    single int is then returned as ``("loc", <int>)``.
    """
    e = rainbow_sha3(e)

    if match(e, ("sha3", ("data", ...))):
        # Drop the "data" wrapper and keep its terms.
        e = ("sha3", *e[1][1:])

    hashed_int = match(e, ("sha3", ":int:loc"))
    if hashed_int:
        return ("loc", hashed_int.loc)
def _mask_op(exp, size=256, offset=0, shl=0, shr=0):
    """Mask ``exp``; a zero-width mask is always ``0``.

    A division by one, ``("div", num, 1)``, is unwrapped to ``num`` before
    any further processing.
    """
    if size == 0:
        return 0

    # should be done somewhere else, but it's 0:37 at night
    div_by_one = match(exp, ("div", ":num", 1))
    if div_by_one:
        exp = div_by_one.num
def other_1(exp):
    """Unwrap a quoted string literal from a matching ``mask_shl``.

    The string is returned when the mask fields are ``(size, 256 - size,
    size - 256)``, the size agrees with the string length, and the string
    starts and ends with a single quote.  Otherwise ``None``.
    """
    m = match(
        exp, ("mask_shl", ":int:size", ":n_size", ":size_n", ":str:val"))
    if not m:
        return None

    size, val = m.size, m.val

    fields_agree = 256 - size == m.n_size and size - 256 == m.size_n
    # +16 because the two quote characters are part of the string
    length_agrees = fields_agree and size + 16 == len(val) * 8
    is_quoted = length_agrees and len(val) > 0 and val[0] == val[-1] == "'"

    if is_quoted:
        return val
def lt_op(left, right):
    """Evaluate ``left < right``.

    Two plain ints are compared directly.  A left side of the form
    ``("add", <int>, ("max", ...))`` is rewritten so the constant is added
    to every ``max`` term.
    """
    if type(left) is int and type(right) is int:
        return left < right

    m = match(left, ("add", ":int:num", ":max"))
    if m and opcode(m.max) == "max":
        # n + max(a, b, ...) == max(n + a, n + b, ...)
        left = ("max", *(add_op(term, m.num) for term in m.max[1:]))
def add_to_arr(exp):
    """Turn ``("add", x, loc)`` (either order) into ``("array", x, loc)``.

    Returns ``None`` when ``exp`` is not a two-term addition or neither
    term has the ``loc`` opcode.
    """
    m = match(exp, ("add", ":left", ":right"))
    if not m:
        return None

    # Put the loc term second.
    if opcode(m.left) == "loc":
        index, base = m.right, m.left
    else:
        index, base = m.left, m.right

    if opcode(base) == "loc":
        return ("array", index, base)
def find_default(exp):
    """Return the jump destination at the start of a call-free false-branch.

    Considers ``if`` nodes whose condition mentions ``("cd", 0)`` and whose
    false-branch contains nothing found by ``func_calls``.  If that branch
    starts with a ``("jd", x)`` entry, ``int(x)`` is returned.
    """
    m = match(exp, ("if", ":cond", ":if_true", ":if_false"))
    if not (m and str(("cd", 0)) in str(m.cond)):
        return None

    if not (find_f_list(m.if_false, func_calls) == []):
        return None

    first = m.if_false[0]
    dest = match(first, ("jd", ":jd"))
    if dest:
        return int(dest.jd)
def to_bytes(exp):
    """Convert a bit count to ``(bytes, remainder)``.

    * int: ``((exp + 7) // 8, exp % 8)``.
    * ``("mask_shl", 253, 0, 3, x)``: ``(x, 0)``.
    * any int-field ``mask_shl`` with ``shl >= 3``: the same mask with the
      shift reduced by 3, and a remainder of ``0``.
    """
    if type(exp) is int:
        whole_bytes = (exp + 7) // 8
        return whole_bytes, exp % 8

    # A multiplication by 8 wrapped in a full mask: the operand is the
    # byte count.
    if type(exp) is tuple and exp[:4] == ("mask_shl", 253, 0, 3):
        return exp[4], 0

    m = match(
        exp, ("mask_shl", ":int:size", ":int:offset", ":int:shl", ":val"))
    if m and m.shl >= 3:
        return ("mask_shl", m.size, m.offset, m.shl - 3, m.val), 0
def get_type(stordefs):
    """Collect the sizes and offsets used by map/array storage definitions.

    Only ``("stor", size, off, (op, idx, ...))`` entries whose ``op`` is
    ``"map"`` or ``"array"`` are considered; an offset is recorded only
    when it is provably non-negative.
    """
    sizes, offsets = set(), set()

    for stordef in stordefs:
        m = match(stordef, ("stor", ":size", ":off", (":op", ":idx", ...)))
        if m and m.op in ("map", "array"):
            sizes.add(m.size)
            if safe_le_op(0, m.off) is True:
                offsets.add(m.off)
def __try_add(self, other):
    # Rewrites `self` when it is a multiplication of a left-shifted mask:
    # ("mul", num, ("mask_shl", size, off, shl, val)) with shl > 0 becomes
    # a "mul" over an unshifted mask that is `shl` bits wider.
    # `other` is not used in the visible part of this function.
    #
    # NOTE(review): the new factor is computed as `m.num + 2**m.shl`.
    # Folding a left shift into a multiplier would normally be
    # `m.num * 2**m.shl` -- confirm whether the `+` is intended.
    if (m := match(
            self,
            ("mul", ":num", ("mask_shl", ":int:size", ":int:off", ":int:shl", ":val")),
    )) and m.shl > 0:
        self = (
            "mul",
            m.num + 2**m.shl,
            ("mask_shl", m.size + m.shl, m.off, 0, m.val),
        )
def f(exp):
    """Return the number of a ``("loc", num)`` node, else ``None``.

    Non-tuples and ``storage`` / ``stor`` nodes always yield ``None``.
    """
    if type(exp) is not tuple:
        return None

    # If we have a storage reference within the storage, don't follow
    # this one when looking for loc.
    if opcode(exp) in ("storage", "stor"):
        return None

    loc = match(exp, ("loc", ":num"))
    if loc:
        return loc.num
def mask_to_mul(exp):
    """Rewrite small pure shifts expressed as ``mask_shl`` into mul/div.

    * Left shift by 1..8 bits (``offset == 0``, ``size == 256 - shl``)
      becomes ``("mul", 2**shl, val)``.
    * Right shift by 1..8 bits (``shl < 0``, ``offset == -shl``,
      ``size == 256 - offset``) becomes ``("div", val, 2**-shl)``.

    The original emitted ``("div", 2**shl, val)`` for the right shift.
    With a negative ``shl`` that is a fractional float, placed in the
    dividend position.  A right shift by ``k`` bits divides the value by
    ``2**k``, so the value is the dividend and the power of two the
    divisor.

    Returns ``None`` when no rewrite applies.
    """
    m = match(
        exp, ("mask_shl", ":int:size", ":int:offset", ":int:shl", ":val"))
    if m:
        size, offset, shl, val = m.size, m.offset, m.shl, m.val

        if shl > 0 and offset == 0 and size == 256 - shl:
            if shl <= 8:
                return ("mul", 2**shl, val)

        if shl < 0 and offset == -shl and size == 256 - offset:
            if shl >= -8:
                # -shl is positive here, so the divisor stays an int.
                return ("div", val, 2**-shl)
def slice_exp(exp, left, right):
    """Slice ``exp`` from byte ``left`` up to byte ``right``.

    For ``("mem", ("range", start, length))`` the result is a narrower
    memory range when the slice provably fits, otherwise ``None``.
    """
    size = sub_op(right, left)

    logger.debug(f"slicing {exp}, offset {left} bytes, until {right} bytes")

    # e.g. mem[32 len 10], 2, 4 == mem[34,2]
    mem = match(exp, ("mem", ("range", ":rleft", ":rlen")))
    if mem:
        if not safe_le_op(add_op(left, size), mem.rlen):
            return None
        return ("mem", ("range", add_op(mem.rleft, left), size))
def pretty_line(r, add_color=True):
    """Yield the printable line(s) for one trace entry ``r``.

    Bare strings and ``("comment", text)`` entries are both rendered as
    gray ``# ...`` comment lines.
    """
    # Pre-bound helpers carrying the current colour setting.
    col = partial(colorize, add_color=add_color)
    pret = partial(prettify, parentheses=False, add_color=add_color)

    if type(r) is str:
        yield COLOR_GRAY + "# " + r + ENDC
        return

    comment = match(r, ("comment", ":text"))
    if comment:
        yield COLOR_GRAY + "# " + prettify(comment.text, add_color=False) + ENDC
def arr_rem_mul(exp):
    """Drop the index shift of an array access into an array of structs.

    For ``("array", ("mask_shl", size, off, shl, idx), loc)`` the shift is
    removed when ``self.stor_defs`` (from the enclosing scope) defines the
    location as an array of structs of size ``2**shl``.
    """
    m = match(
        exp,
        ("array", ("mask_shl", ":size", ":off", ":int:shl", ":idx"), ":loc"),
    )
    if not m:
        return None

    struct_size = 2**m.shl
    e_loc = get_loc(m.loc)

    for stor_def in self.stor_defs:
        # Every storage definition must have the ("def", _, loc, def) shape.
        assert match(stor_def, ("def", Any, ":d_loc", ":d_def"))

        if match(stor_def, ("def", Any, e_loc, ("array", ("struct", struct_size)))):
            return ("array", ("mask_shl", m.size, m.off, 0, m.idx), m.loc)
def fill_mem(exp, mem_idx, mem_val):
    """Substitute ``mem[mem_idx] == mem_val`` into ``exp``.

    Fast path: when ``mem_idx`` is a range starting at a variable that
    ``exp`` does not contain, there can be no match and ``exp`` is
    returned unchanged.  The path asserts that the enclosing-scope
    ``strict`` flag is off.
    """
    logger.debug(f"filling mem: {exp} with mem[{mem_idx}] == {mem_val}")

    # Ugly, but shaves off 15% exec time.
    var_range = match(mem_idx, ("range", ("var", ":num"), Any))
    if var_range and not contains(exp, ("var", var_range.num)):
        assert not strict
        return exp
def find_storage_names(functions):
    """Derive storage names from getter functions.

    Returns a dict mapping each function's ``getter`` expression to a name
    built from the function name: a leading ``get`` is stripped, the first
    letter lower-cased (unless the name is all upper-case), the argument
    list removed, and ``Address`` appended for 160-bit storage whose name
    does not already hint at an address.
    """
    address_hints = ("address", "addr", "account", "owner")
    res = {}

    for func in functions:
        if not func.getter:
            continue

        getter = func.getter
        assert opcode(getter) in ("storage", "struct", "bool")

        # func name into potential storage name
        new_name = func.name

        if new_name[:3] == "get" and len(new_name.split("(")[0]) > 3:
            new_name = new_name[3:]

        if new_name != new_name.upper():
            # otherwise we get stuff like bILLIONS in
            # 0xF0160428a8552AC9bB7E050D90eEADE4DDD52843
            new_name = new_name[0].lower() + new_name[1:]

        new_name = new_name.split("(")[0]

        if match(getter, ("storage", 160, ...)):
            lowered = new_name.lower()
            if not any(hint in lowered for hint in address_hints):
                new_name += "Address"

        res[getter] = new_name

    return res
def _try_add(self, other):
    """Try to add ``(mul a x)`` and ``(mul b y)`` into a single term.

    Both operands must be ``("mul", <int>, ...)`` nodes, otherwise
    ``None`` is returned.  ('self' name to be refactored.)

    Handled case: ``-1 * val`` plus ``mul * (val << shl)``, where the
    shift is a ``mask_shl`` with ``offset == 0`` and ``size == 256 - shl``.
    The sum is ``(mul * 2**shl - 1) * val``.

    The original computed ``mul * (2**shl - 1)``, which is the same only
    for ``mul == 1``.  It also did so by mutating the match result.
    """
    if not match(self, ("mul", int, Any)) or not match(other, ("mul", int, Any)):
        return None

    if ((ms := match(self, ("mul", -1, ":val"))) and (mo := match(
            other,
            ("mul", ":mul", ("mask_shl", ":int:other_size", 0, ":int:shl",
                             ms.val)),
    )) and mo.other_size == 256 - mo.shl):
        # -val + mul * (val * 2**shl) == (mul * 2**shl - 1) * val
        factor = mo.mul * 2**mo.shl - 1
        return mul_op(factor, ms.val)
def split_or(value):
    """Split an ``or`` / ``mask_shl`` expression into its terms.

    Anything that is neither ``or`` nor ``mask_shl`` is returned as a
    single full-width row, ``[(256, 0, value)]``.  A lone ``mask_shl`` is
    treated as a one-term ``or``.  Each term is then normalised:

    * ``("bool", x)`` is wrapped in an 8-bit mask,
    * ``"caller"`` is wrapped in a 160-bit mask,
    * ``"block.timestamp"`` is wrapped in a 64-bit mask,
    * ``("mul", 1, x)`` is unwrapped to ``x``.

    The original wrapped ``"block.timestamp"`` in a mask of ``"caller"``
    (a copy-paste slip), silently replacing the timestamp with the
    caller.  The mask now wraps ``"block.timestamp"``.
    """
    orig_value = value

    if opcode(value) not in ("or", "mask_shl"):
        return [(256, 0, value)]

    if opcode(value) == "mask_shl":
        value = ("or", value)

    opcode_, *terms = value
    assert opcode_ == "or"

    ret_rows = []

    for row in terms:
        if m := match(row, ("bool", ":arg")):
            # does weird things if size == 1, in loops.activateSafeMode
            row = ("mask_shl", 8, 0, 0, ("bool", m.arg))

        if row == "caller":
            row = ("mask_shl", 160, 0, 0, "caller")

        if row == "block.timestamp":
            row = ("mask_shl", 64, 0, 0, "block.timestamp")

        if m := match(row, ("mul", 1, ":val")):
            row = m.val