示例#1
0
def get_param_name(cd, add_color=False, func=None):
    """Return a human-readable name for the calldata reference ``cd``.

    ``cd`` is returned unchanged when the module-level ``_abi`` or ``_func``
    state is missing, or when ``_func`` has no "params" entry.

    For a non-integer location, ``cd`` is first passed through
    ``cleanup_mul_1``. Then two location shapes are recognised and rendered
    as "<name>.length" via ``colorize(..., COLOR_GREEN, add_color)``:

    * ``("add", 4, ("param", X))`` -> ``X + ".length"``
    * ``("add", 4, ("cd", X))``    -> name of ``("cd", X)`` + ".length"

    Args:
        cd: calldata node, assumed to be ``("cd", loc)``. ``.loc`` is read
            from the match result unconditionally (TODO confirm what
            ``match`` returns on a miss).
        add_color: forwarded to ``colorize`` for the outer result only.
        func: in the code shown, only forwarded to the recursive call. The
            checks above use the module-level ``_func`` instead.

    NOTE(review): this excerpt ends after the two patterns above. Inputs
    that match neither fall off the end and return None in the code as shown.
    """
    global _func  # only read here; ``global`` is not needed just to read
    global _abi
    loc = match(cd, ("cd", ":loc")).loc

    if _abi is None:
        return cd

    if _func is None:
        return cd

    if "params" not in _func:
        return cd

    if type(loc) != int:
        # NOTE(review): ``cd`` is rewritten here but ``loc`` is not re-derived,
        # so the patterns below still match the pre-cleanup location.
        cd = cleanup_mul_1(cd)

        if m := match(loc, ("add", 4, ("param", ":point_loc"))):
            return colorize(m.point_loc + ".length", COLOR_GREEN, add_color)

        if m := match(loc, ("add", 4, ("cd", ":point_loc"))):
            # The inner call does not receive ``add_color``; only the final
            # "<name>.length" string is passed through colorize here.
            return colorize(
                str(get_param_name(
                    ("cd", m.point_loc), func=func)) + ".length",
                COLOR_GREEN,
                add_color,
            )
示例#2
0
    def make_asts(self):
        """Build an AST for every function, then index storage masks contract-wide.

        AST creation is done at the contract level rather than per function
        because some simplifications (type/field removal) need insight into
        all of the functions, not just a single one.

        Side effects: sets ``func.ast`` on every entry of ``self.functions``.

        NOTE(review): this excerpt ends after defining ``cleanup``. The mask
        indexes and ``cleanup`` are not used in the code shown.
        """

        for func in self.functions:
            func.ast = self.make_ast(func.trace)

        def find_stor_masks(exp):
            # Collect ``("type", ...)`` nodes; anything else contributes nothing.
            if opcode(exp) == "type":
                return [exp]
            else:
                return []

        # Deduplicated "type" nodes that find_f_list collects from every
        # function's AST (presumably a recursive walk -- confirm).
        stor_masks = frozenset(
            find_f_list([f.ast for f in self.functions], find_stor_masks))

        # Group those masks by storage location and by storage name.
        stor_loc_to_masks = collections.defaultdict(set)
        stor_name_to_masks = collections.defaultdict(set)
        for mask in stor_masks:
            stor_loc_to_masks[get_loc(mask)].add(mask)
            stor_name_to_masks[get_name(mask)].add(mask)

        def cleanup(exp):
            # Strip field/type wrappers around a storage length read, leaving
            # the bare ``("stor", ("length", idx))``.

            if m := match(exp, ("field", 0, ("stor", ("length", ":idx")))):
                return ("stor", ("length", m.idx))

            if m := match(exp, ("type", 256, ("field", 0,
                                              ("stor", ("length", ":idx"))))):
                return ("stor", ("length", m.idx))
示例#3
0
            def add_path(line):
                """For a ``goto`` line, return the enclosing ``path`` with the
                goto's variable bindings substituted in, followed by the line.

                ``path`` is a free variable from the enclosing scope (not shown
                here). In this excerpt a non-goto line falls off the end and
                yields None.
                """
                if m := match(line, ("goto", Any, ":svs")):
                    path2 = path
                    # Each svs entry unpacks as (_, var index, value); every
                    # ("var", idx) in the path is presumably replaced by value.
                    for _, v_idx, v_val in m.svs:
                        path2 = replace(path2, ("var", v_idx), v_val)

                    return path2 + [line]
示例#4
0
def simplify(exp):
    """Recursively simplify ``exp``.

    Cases visible in this excerpt:

    * ``("max", *terms)``: each term is simplified and folded with ``max_op``.
      If ``max_op`` raises for any term, the symbolic ``("max", ...)`` of the
      simplified terms is returned instead.
    * ``("mask_shl", size, offset, shl, val)``: all four arguments are
      simplified. Fully concrete masks are evaluated with ``apply_mask``, and
      the identity mask ``(256, 0, 0)`` collapses to ``val``.

    NOTE(review): the excerpt ends here. Other inputs fall off the end (None)
    in the code as shown.
    """

    if opcode(exp) == "max":
        terms = exp[1:]
        els = [simplify(e) for e in terms]
        # Fold starting from -(2**256), which is below any 256-bit value.
        # With no terms at all, this sentinel itself is returned.
        res = -(2**256)
        for e in els:
            try:
                res = max_op(res, e)
            except Exception:
                # presumably max_op cannot order some symbolic terms; keep
                # the max symbolic instead
                return ("max", ) + tuple(els)
        return res

    if (m := match(exp, ("mask_shl", ":size", ":offset", ":shl", ":val"))):
        size, offset, shl, val = (
            simplify(m.size),
            simplify(m.offset),
            simplify(m.shl),
            simplify(m.val),
        )

        if all_concrete(size, offset, shl, val):
            return apply_mask(val, size, offset, shl)

        # A full-width, unshifted mask is a no-op.
        if (size, offset, shl) == (256, 0, 0):
            return val
示例#5
0
 def loc_to_name(exp):
     """Give an integer storage location ``("loc", n)`` a readable name.

     Returns ``("name", label, n)``. ``label`` is ``"stor"`` followed by
     ``n`` in decimal when ``n < 1000``. Otherwise it is ``"stor"`` followed
     by the upper-cased leading hex digits ``hex(n)[2:6]``. Returns None for
     anything that is not an integer ``loc`` node.
     """
     found = match(exp, ("loc", ":int:num"))
     if not found:
         return None

     slot = found.num
     suffix = str(slot) if slot < 1000 else hex(slot)[2:6].upper()
     return ("name", "stor" + suffix, slot)
示例#6
0
def pretty_stor(exp, add_color=True):
    """Render a storage expression as a string.

    Visible case: ``("stor", ("length", idx))`` renders as the rendering of
    ``idx`` followed by ``colorize(".length", color=COLOR_GREEN, ...)``, with
    ``add_color`` forwarded throughout.

    NOTE(review): the excerpt ends here. ``pret`` is unused, and other inputs
    fall off the end (None) in the code as shown.
    """
    col = partial(colorize, color=COLOR_GREEN, add_color=add_color)
    stor = partial(pretty_stor, add_color=add_color)  # recursive renderer
    pret = partial(prettify, parentheses=False, add_color=add_color)

    if m := match(exp, ("stor", ("length", ":idx"))):
        return stor(m.idx) + col(".length")
示例#7
0
    def make_trace(self):
        """Flatten this node and its successors into a single trace list.

        Returns ``[("undefined", "decompilation didn't finish")]`` when this
        node has no trace. Otherwise the result is a prefix (an optional
        ``jd``/"?" marker and an optional ``label`` entry) followed by this
        node's trace. A trailing ``jump`` is replaced by the target's own
        flattened trace, and a trailing ``if`` has both branches flattened
        recursively.

        NOTE(review): recursion follows every jump and branch, so very long
        chains may hit the recursion limit (confirm typical trace sizes).
        This excerpt ends after the ``if`` case; other final lines fall off
        the end (None).
        """
        if self.trace is None:
            return [("undefined", "decompilation didn't finish")]

        # Label nodes re-emit their entry variable assignments as ``setvar``s.
        begin_vars = []
        if self.is_label():
            for _, var_idx, var_val, _ in self.label.begin_vars:
                begin_vars.append(("setvar", var_idx, var_val))

        # In vm.just_fdests mode (presumably "only collect function
        # destinations"), non-revert nodes get a jd marker when they start
        # with a jump, or "?" otherwise.
        if self.vm.just_fdests and self.trace != [("revert", 0)]:
            t = self.trace[0]
            if match(t, ("jump", ":target_node", ...)):
                begin = [("jd", str(self.jd[0]))]  # , str(self.trace))]
            else:
                begin = ["?"]
        else:
            begin = []

        begin += [("label", self, tuple(begin_vars))] if self.is_label() else []

        last = self.trace[-1]

        if opcode(last) == "jump":
            # Splice the jump target's flattened trace in place of the jump.
            return begin + self.trace[:-1] + last[1].make_trace()

        if m := match(last, ("if", ":cond", ":if_true", ":if_false")):
            if_true = m.if_true.make_trace()
            if_false = m.if_false.make_trace()
            return begin + self.trace[:-1] + [("if", m.cond, if_true, if_false)]
示例#8
0
def apply_mask_to_storage(exp, size, offset, shl):
    """Fold a mask ``(size, offset, shl)`` into a ``("storage", ...)`` node.

    The mask's offset is absorbed into the storage slice: stor_offset grows,
    stor_size shrinks, and shl grows by the same amount. The slice is then
    clipped to ``size`` when ``safe_lt_op(size, stor_size)`` holds. It
    becomes the constant 0 when ``safe_le_op(stor_size, 0)`` is True.

    With no remaining shift, the adjusted storage node is returned. With a
    shift, a slice of exactly ``size`` bits at literal offset 0 is returned
    as ``("storage", size, minus_op(shl), idx)``.

    Raises:
        AssertionError: if ``exp`` is not a 4-element ``storage`` node. This
            is an ``assert``, so it is skipped under ``python -O``.

    NOTE(review): this excerpt ends here. Shifted slices that do not match
    the pattern below fall off the end (None) in the code as shown.
    """
    m = match(exp, ("storage", ":stor_size", ":stor_offset", ":stor_idx"))
    assert m
    stor_size, stor_offset, stor_idx = m.stor_size, m.stor_offset, m.stor_idx

    #    shr = minus_op(shl)

    stor_offset = add_op(stor_offset, offset)
    stor_size = sub_op(stor_size, offset)
    shl = add_op(shl, offset)
    # The offset is now fully absorbed. From here on it is always 0, so the
    # ``offset == 0`` test below always holds.
    offset = 0

    if safe_lt_op(size, stor_size):
        stor_size = size

    if safe_le_op(stor_size, 0) is True:
        return 0

    res = ("storage", stor_size, stor_offset, stor_idx)

    shr = 0  # only reassigned in the shifted branch below

    if shl == 0:
        return res
    else:
        if (m := match(res,
                       ("storage", size, 0, ":stor_idx"))) and offset == 0:
            stor_idx = m.stor_idx
            shr = minus_op(shl)
            return ("storage", size, shr, stor_idx)
示例#9
0
    def simplify_string_getter_from_storage(self):
        """Heuristically find string getters and replace them with a simplified version.

        Only read-only functions with at least one return are considered.
        Every return must have the shape
        ``("return", ("data", ("arr", ("storage", 256, 0, ("length", loc)), ...)))``;
        the method returns as soon as one does not.

            test cases: unicorn
                        0xF7dF66B1D0203d362D7a3afBFd6728695Ae22619 name
                        0xf8e386EDa857484f5a12e4B5DAa9984E06E73705 version

            if you want to see how it works, turn this func off
            and see how test cases decompile

        NOTE(review): this excerpt ends after ``loc`` is extracted inside the
        loop. The actual replacement is not shown.
        """

        if not self.read_only:
            return

        if len(self.returns) == 0:
            return

        for r in self.returns:
            if not (m := match(
                    r,
                (
                    "return",
                    ("data", ("arr", ("storage", 256, 0,
                                      ("length", ":loc")), ...)),
                ),
            )):
                return
            loc = m.loc
示例#10
0
 def other_2(exp):
     """Simplify ``if a == b: return (a == b)`` to ``if a == b: return bool(1)``.

     Inside the true branch the comparison holds, so returning it is the
     same as returning constant true. Only three-element
     ``("if", ("eq", a, b), body)`` nodes whose body is exactly
     ``[("return", ("eq", a, b))]`` are rewritten; anything else yields None.
     """
     found = match(exp, ("if", ("eq", ":a", ":b"), ":if_true"))
     if not found:
         return None

     comparison = ("eq", found.a, found.b)
     if not (found.if_true == [("return", comparison)]):
         return None

     return ("if", comparison, [("return", ("bool", 1))])
示例#11
0
def make(trace):
    """Rebuild ``trace`` with ``while`` loops recovered from labels.

    ``if`` lines are rebuilt with both branches processed recursively. For a
    ``label`` line, the rest of the trace is handed to ``to_while``. On
    success the result is ``before`` (with the label's variables
    substituted), then ``("while", cond, inside, repr(jd), vars)``, then the
    processed ``remaining`` lines, and the function returns immediately.

    NOTE(review): in this excerpt, lines that are neither ``if`` nor a
    successfully converted ``label`` are not appended to ``res``. A label
    that ``to_while`` rejects is skipped via ``continue``. If no label
    converts, the loop ends without a ``return`` (None). The original
    function presumably continues beyond what is shown -- confirm.
    """
    res = []

    for idx, line in enumerate(trace):
        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            res.append(("if", m.cond, make(m.if_true), make(m.if_false)))

        elif m := match(line, ("label", ":jd", ":vars", ...)):
            jd, vars = m.jd, m.vars  # NOTE: ``vars`` shadows the builtin
            try:
                before, inside, remaining, cond = to_while(trace[idx + 1:], jd)
            except Exception:
                # Any failure to recognise a loop is treated as "not a loop".
                continue

            inside = inside  # + [str(inside)]  -- currently a no-op

            inside = make(inside)
            remaining = make(remaining)

            # Substitute the label's entry variables into the pre-loop part.
            for _, v_idx, v_val in vars:
                before = replace(before, ("var", v_idx), v_val)
            before = make(before)

            res.extend(before)
            res.append(("while", cond, inside, repr(jd), vars))
            res.extend(remaining)

            return res
示例#12
0
def to_while(trace, jd, path=None):
    """Scan ``trace`` for the loop whose back-edge targets label ``jd``.

    Lines are consumed one at a time. For an ``if`` where one branch is a
    revert, a ``("require", ...)`` entry for the surviving condition is
    appended to ``path``, and scanning continues down the other branch.
    Otherwise the loop label ``jd`` must occur in the jump destinations of
    exactly one branch; this is checked with ``assert``.

    Args:
        trace: list of trace lines to scan.
        jd: jump destination / label identifying the loop head.
        path: lines accumulated so far. NOTE(review): a non-empty list passed
            here is appended to in place, while an empty or missing one is
            replaced by a fresh list.

    Raises:
        RuntimeError: if the trace runs out before the loop is resolved.

    NOTE(review): this excerpt ends after defining ``add_path``. The code
    that builds the ``(before, inside, remaining, cond)`` result callers
    unpack is not shown.
    """
    path = path or []

    while True:
        if trace == []:
            # Raise explicitly. A bare ``raise`` only happens to produce a
            # RuntimeError when no exception is being handled. If this point
            # is reached from inside an ``except`` block, a bare ``raise``
            # would instead re-raise that unrelated exception.
            raise RuntimeError(
                f"to_while: trace exhausted before loop {jd!r} was resolved")
        line = trace[0]
        trace = trace[1:]

        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            cond, if_true, if_false = m.cond, m.if_true, m.if_false
            if is_revert(if_true):
                # True branch reverts: require the negated condition (via
                # is_zero) and carry on down the false branch.
                path.append(("require", is_zero(cond)))
                trace = if_false
                continue

            if is_revert(if_false):
                path.append(("require", cond))
                trace = if_true
                continue

            jds_true = find_f_list(if_true, get_jds)
            jds_false = find_f_list(if_false, get_jds)

            # The loop label must be reachable from exactly one branch.
            assert (jd in jds_true) != (jd in jds_false), (jds_true, jds_false)

            def add_path(line):
                # Prefix a goto with the path so far, substituting the goto's
                # variable bindings; any other line is returned on its own.
                if m := match(line, ("goto", Any, ":svs")):
                    path2 = path
                    for _, v_idx, v_val in m.svs:
                        path2 = replace(path2, ("var", v_idx), v_val)

                    return path2 + [line]
                else:
                    return [line]
示例#13
0
    def simplify_sha3(e):
        """Simplify a ``sha3`` node after ``rainbow_sha3`` preprocessing.

        A two-element ``("sha3", ("data", *terms))`` node is flattened to
        ``("sha3", *terms)``. A ``sha3`` of a single integer then becomes
        ``("loc", n)``.

        NOTE(review): the excerpt ends here. Other shapes fall off the end
        (None) in the code as shown.
        """
        e = rainbow_sha3(e)

        if match(e, ("sha3", ("data", ...))):
            terms = e[1][1:]  # "..." -- the elements after the "data" tag
            e = ("sha3", ) + tuple(terms)
        if m := match(e, ("sha3", ":int:loc")):
            return ("loc", m.loc)
示例#14
0
def _mask_op(exp, size=256, offset=0, shl=0, shr=0):
    """Apply a mask ``(size, offset, shl, shr)`` to ``exp``.

    A zero-width mask yields the constant 0. ``("div", num, 1)`` is unwrapped
    to ``num`` before any further processing.

    NOTE(review): the excerpt ends here. The remaining masking logic is not
    shown, and the code as shown returns None for non-zero sizes.
    """
    if size == 0:
        return 0
    #    if (size, offset, shl, shr) == (256, 0, 0, 0):
    #        return exp

    if m := match(exp, ("div", ":num", 1)):
        exp = m.num  # should be done somewhere else, but it's 0:37 at night
示例#15
0
 def other_1(exp):
     """Recover a quoted string literal from a ``mask_shl`` wrapper.

     Matches ``("mask_shl", size, n_size, size_n, val)`` with integer ``size``
     and string ``val`` where all of the following hold:

     * ``n_size == 256 - size`` and ``size_n == size - 256``;
     * ``size + 16 == 8 * len(val)`` (the +16 bits cover the two quote chars);
     * ``val`` is non-empty and starts and ends with a single quote.

     Returns ``val`` itself in that case, otherwise None.
     """
     found = match(
         exp, ("mask_shl", ":int:size", ":n_size", ":size_n", ":str:val"))
     if not found:
         return None

     width, text = found.size, found.val

     # The offset and shift slots must mirror the width around 256.
     if not (256 - width == found.n_size and width - 256 == found.size_n):
         return None

     # +16 because '' in strings: two quote characters of 8 bits each.
     if not (width + 16 == len(text) * 8 and len(text) > 0):
         return None

     if not (text[0] == text[-1] == "'"):
         return None

     return text
示例#16
0
def lt_op(left, right):  # left < right
    """Return ``left < right`` for concrete ints; otherwise try to normalise.

    When ``left`` is ``("add", num, ("max", t1, t2, ...))`` with an integer
    ``num``, the addition is pushed into the max, giving
    ``("max", t1 + num, t2 + num, ...)`` (built with ``add_op``).

    NOTE(review): the excerpt ends here. The symbolic comparison is not shown,
    and the code as shown returns None for non-int inputs.
    """
    if type(left) == int and type(right) == int:
        return left < right

    if (m := match(left,
                   ("add", ":int:num", ":max"))) and opcode(m.max) == "max":
        terms = m.max[1:]
        left = ("max", ) + tuple(add_op(t, m.num) for t in terms)
示例#17
0
    def add_to_arr(exp):
        """Turn ``("add", x, ("loc", ...))``, in either operand order, into
        ``("array", x, ("loc", ...))``.

        NOTE(review): the excerpt ends here. Other inputs fall off the end
        (None) in the code as shown.
        """
        if m := match(exp, ("add", ":left", ":right")):
            left, right = m.left, m.right
            # Normalise so that a "loc" operand, if any, ends up on the right.
            if opcode(left) == "loc":
                right, left = left, right

            if opcode(right) == "loc":
                return ("array", left, right)
示例#18
0
            def find_default(exp):
                """Return the jump destination of what looks like a default
                branch, or None.

                Matches an ``if`` whose condition's string form contains
                ``str(("cd", 0))`` (presumably a check on the first calldata
                word -- confirm) and whose false branch yields no hits for
                ``find_f_list(..., func_calls)`` (presumably "makes no
                calls"). If that branch's first line is ``("jd", x)``, returns
                ``int(x)``.

                NOTE(review): ``m.if_false[0]`` raises IndexError if the false
                branch is empty. Other inputs fall off the end (None).
                """

                if (m := match(
                        exp,
                    ("if", ":cond", ":if_true", ":if_false"))) and str(
                        ("cd", 0)) in str(m.cond):
                    if find_f_list(m.if_false, func_calls) == []:
                        fi = m.if_false[0]
                        if m2 := match(fi, ("jd", ":jd")):
                            return int(m2.jd)
示例#19
0
def to_bytes(exp):
    """Convert a bit count ``exp`` into a pair, presumably (bytes, leftover bits).

    * int: ``((exp + 7) // 8, exp % 8)``, i.e. ceil(exp / 8) and exp mod 8.
    * exactly ``("mask_shl", 253, 0, 3, x)``: ``(x, 0)``. This reads the node
      as ``x * 8`` bits, assuming shl=3 is a left shift by 3 (TODO confirm).
    * ``mask_shl`` with an integer ``shl >= 3``: the same node with ``shl``
      reduced by 3 (one division by 8), and leftover 0.

    NOTE(review): the excerpt ends here. Other inputs fall off the end (None)
    in the code as shown.
    """
    if type(exp) == int:
        return (exp + 7) // 8, exp % 8

    if type(exp) == tuple and exp[:4] == ("mask_shl", 253, 0, 3):
        return exp[4], 0

    if (m := match(exp, ("mask_shl", ":int:size", ":int:offset", ":int:shl",
                         ":val"))) and m.shl >= 3:
        return ("mask_shl", m.size, m.offset, m.shl - 3, m.val), 0
示例#20
0
 def get_type(stordefs):
     """Collect sizes and non-negative offsets of map/array storage accesses.

     For each entry of ``stordefs`` shaped ``("stor", size, off, (op, idx, ...))``
     with ``op`` in ("map", "array"), ``size`` is recorded. ``off`` is also
     recorded when ``safe_le_op(0, off)`` is True.

     NOTE(review): this excerpt ends before ``sizes``/``offsets`` are used;
     as shown, the function returns None.
     """
     sizes = set()
     offsets = set()
     for s in stordefs:
         if (m := match(
                 s, ("stor", ":size", ":off",
                     (":op", ":idx", ...)))) and m.op in ("map", "array"):
             sizes.add(m.size)
             if safe_le_op(0, m.off) is True:
                 offsets.add(m.off)
示例#21
0
def __try_add(self, other):
    """Rewrite ``(mul num (mask_shl size off shl val))`` when ``shl > 0``.

    The shift is moved out of the mask and into the multiplier, giving
    ``("mul", num + 2**shl, ("mask_shl", size + shl, off, 0, val))``.

    NOTE(review): ``num * (x << shl)`` equals ``(num * 2**shl) * x``, so
    ``num + 2**shl`` looks suspicious. Widening the mask to ``size + shl``
    bits also looks suspicious: assuming mask_shl keeps ``size`` bits from
    ``off``, it admits higher bits of ``val``. Confirm the intended identity
    before relying on this. ``self`` is an ordinary parameter (this is a
    module-level function), ``other`` is unused in the code shown, and the
    excerpt ends after this rewrite.
    """
    if (m := match(
            self,
        ("mul", ":num",
         ("mask_shl", ":int:size", ":int:off", ":int:shl", ":val")),
    )) and m.shl > 0:
        self = (
            "mul",
            m.num + 2**m.shl,
            ("mask_shl", m.size + m.shl, m.off, 0, m.val),
        )
示例#22
0
    def f(exp):
        """Return ``num`` for a ``("loc", num)`` node, else None.

        Non-tuples and ``storage``/``stor`` nodes yield None. Per the inline
        comment, this is so that a search for ``loc`` does not follow
        storage references nested inside storage.

        NOTE(review): the excerpt ends here. Other tuples fall off the end
        (None) in the code as shown.
        """
        if type(exp) != tuple:
            return None

        elif opcode(exp) in ("storage", "stor"):
            #        elif exp ~ ('storage', ...): # if we have a storage reference within the storage
            # don't follow this one when looking for loc
            return None

        elif m := match(exp, ("loc", ":num")):
            return m.num
示例#23
0
def mask_to_mul(exp):
    """Rewrite a ``mask_shl`` that is really a plain shift as ``mul``/``div``.

    Assumes ``mask_shl(size, offset, shl, val)`` keeps ``size`` bits of
    ``val`` starting at bit ``offset`` and then shifts left by ``shl`` (right
    when negative) -- TODO confirm against ``apply_mask``. Under that reading:

    * ``mask_shl(256 - s, 0, s, val)`` with ``0 < s <= 8`` is ``val * 2**s``
      and becomes ``("mul", 2**s, val)``.
    * ``mask_shl(256 - s, s, -s, val)`` with ``0 < s <= 8`` is
      ``val // 2**s`` and becomes ``("div", val, 2**s)``. The dividend comes
      first, consistent with ``("div", num, 1)`` meaning ``num``.

    Other inputs fall through. The code as shown then returns None.
    """
    if m := match(
            exp, ("mask_shl", ":int:size", ":int:offset", ":int:shl", ":val")):
        size, offset, shl, val = m.size, m.offset, m.shl, m.val
        if shl > 0 and offset == 0 and size == 256 - shl:
            if shl <= 8:
                return ("mul", 2**shl, val)

        if shl < 0 and offset == -shl and size == 256 - offset:
            if shl >= -8:
                # Dividend first, and the divisor must be the positive int
                # 2**-shl. Here 2**shl would be a float below 1.
                return ("div", val, 2**-shl)
示例#24
0
def slice_exp(exp, left, right):
    """Return the part of ``exp`` between byte offsets ``left`` and ``right``.

    For ``("mem", ("range", rleft, rlen))``: if ``safe_le_op`` confirms that
    ``left + (right - left)`` fits within ``rlen``, returns
    ``("mem", ("range", rleft + left, right - left))``; otherwise None.

    NOTE(review): the excerpt ends here. Other expression kinds fall off the
    end (None) in the code as shown.
    """
    size = sub_op(right, left)

    # NOTE: the f-string is formatted even when debug logging is disabled.
    logger.debug(f"slicing {exp}, offset {left} bytes, until {right} bytes")
    # e.g. mem[32 len 10], 2, 4 == mem[34,2]

    if m := match(exp, ("mem", ("range", ":rleft", ":rlen"))):
        rleft, rlen = m.rleft, m.rlen
        if safe_le_op(add_op(left, size), rlen):
            return ("mem", ("range", add_op(rleft, left), size))
        else:
            return None
示例#25
0
def pretty_line(r, add_color=True):
    """Yield pretty-printed output line(s) for the trace entry ``r``.

    Cases visible in this excerpt: a plain string is yielded as a
    ``COLOR_GRAY`` ``# `` comment. A ``("comment", text)`` node is yielded
    the same way, using ``prettify(text, add_color=False)``.

    NOTE(review): the excerpt ends here. ``col`` and ``pret`` are unused in
    the code shown.
    """
    col = partial(colorize, add_color=add_color)
    pret = partial(prettify, parentheses=False, add_color=add_color)

    if type(r) is str:
        yield COLOR_GRAY + "# " + r + ENDC

    #    elif r ~ ('jumpdest', ...):
    #        pass

    elif m := match(r, ("comment", ":text")):
        yield COLOR_GRAY + "# " + prettify(m.text, add_color=False) + ENDC
示例#26
0
        def arr_rem_mul(exp):
            """Drop the shift from an array index when it matches the element size.

            For ``("array", ("mask_shl", size, off, shl, idx), loc)`` with an
            integer ``shl``: if some entry of ``self.stor_defs`` is
            ``("def", _, get_loc(loc), ("array", ("struct", 2**shl)))``,
            returns the same array node with ``shl`` set to 0. Presumably the
            index was scaled by the struct size (confirm).

            Raises AssertionError (unless run with ``-O``) if a stor_def
            visited by the loop is not a ``("def", _, loc, def)`` node.

            NOTE(review): the excerpt ends here. Non-matching inputs fall off
            the end (None) in the code as shown.
            """
            if m := match(
                    exp,
                ("array",
                 ("mask_shl", ":size", ":off", ":int:shl", ":idx"), ":loc"),
            ):
                size, off, shl, idx, loc = m.size, m.off, m.shl, m.idx, m.loc
                r = 2**shl  # candidate struct size implied by the shift
                e_loc = get_loc(loc)

                for s in self.stor_defs:
                    assert match(s, ("def", Any, ":d_loc", ":d_def"))
                    if match(s, ("def", Any, e_loc, ("array", ("struct", r)))):
                        return ("array", ("mask_shl", size, off, 0, idx), loc)
示例#27
0
def fill_mem(exp, mem_idx, mem_val):
    """Fill references to ``mem[mem_idx]`` in ``exp`` with ``mem_val``.

    This description follows the debug message; only a fast path is visible
    here. When ``mem_idx`` is a range starting at ``("var", n)`` and
    ``exp`` does not contain that variable, ``exp`` is returned unchanged.

    NOTE(review): the excerpt ends here. Other inputs fall off the end (None)
    in the code as shown.
    """

    # speed - if mem_idx is a range that starts at a variable which does
    #         not appear anywhere in exp, there can be no match, so exp
    #         is returned unchanged.
    #
    #         ugly, but shaves off 15% exec time
    logger.debug(f"filling mem: {exp} with mem[{mem_idx}] == {mem_val}")

    if (m := match(
            mem_idx,
        ("range",
         ("var", ":num"), Any))) and not contains(exp, ("var", m.num)):
        # NOTE(review): ``strict`` is not defined in this excerpt (presumably
        # a module-level flag). Because this is an ``assert``, the check
        # disappears under ``python -O``.
        assert not strict
        return exp
示例#28
0
def find_storage_names(functions):
    """Derive human-readable storage names from getter functions.

    For every function whose ``getter`` attribute is truthy, the function's
    name is turned into a candidate storage-variable name as follows:

    * the parameter list (everything from the first ``(``) is dropped;
    * a leading ``get`` is stripped when something follows it
      (``getOwner`` -> ``Owner``);
    * the first character is lower-cased unless the bare name is entirely
      upper-case (``Owner`` -> ``owner``, but ``BILLIONS`` stays);
    * getters of 160-bit storage get an ``Address`` suffix unless the name
      already contains address/addr/account/owner (case-insensitive).

    Args:
        functions: iterable of objects exposing ``getter`` and ``name``.

    Returns:
        dict mapping each getter expression to its derived name. When
        several functions share a getter, the last one wins.

    Raises:
        AssertionError: if a getter's opcode is not storage/struct/bool
            (this is an ``assert``, so it is skipped under ``python -O``).
    """
    res = {}

    for func in functions:
        getter = func.getter
        if not getter:
            continue

        assert opcode(getter) in ("storage", "struct", "bool")

        # Strip the parameter list *first*. The all-upper-case check below
        # must only look at the bare name; otherwise lower-case parameter
        # types (e.g. "BILLIONS(uint256)") defeat it and yield "bILLIONS".
        new_name = func.name.split("(")[0]

        if new_name[:3] == "get" and len(new_name) > 3:
            new_name = new_name[3:]

        if new_name != new_name.upper():
            # otherwise we get stuff like bILLIONS in 0xF0160428a8552AC9bB7E050D90eEADE4DDD52843
            # (safe: an empty name equals its upper-case form, so new_name[0] exists)
            new_name = new_name[0].lower() + new_name[1:]

        if match(getter, ("storage", 160, ...)):
            lowered = new_name.lower()
            address_hints = ("address", "addr", "account", "owner")
            if not any(hint in lowered for hint in address_hints):
                new_name += "Address"

        res[getter] = new_name

    return res
示例#29
0
def _try_add(self, other):
    """Try to add two ``mul`` terms, returning the folded term or None.

    Both operands must be ``("mul", <int>, x)``; otherwise None is returned.
    The identity handled here is
    ``(-1 * val) + (k * mask_shl(256 - s, 0, s, val))``. Reading that mask
    as ``val << s`` (mod 2**256; assumed mask_shl semantics, confirm), the
    sum is ``(k * 2**s - 1) * val``.

    Args:
        self: left ``mul`` term (a plain parameter despite its name).
        other: right ``mul`` term.

    NOTE(review): the excerpt ends here. Pairs that do not match the identity
    fall off the end (None) in the code as shown.
    """
    # tries to add (mul a x) (mul b y)
    # 'self' name to be refactored

    if not match(self, ("mul", int, Any)) or not match(other,
                                                       ("mul", int, Any)):
        return None

    if ((ms := match(self, ("mul", -1, ":val"))) and (mo := match(
            other,
        ("mul", ":mul",
         ("mask_shl", ":int:other_size", 0, ":int:shl", ms.val)),
    )) and mo.other_size == 256 - mo.shl):
        # -val + k * val * 2**s == (k * 2**s - 1) * val.
        # k * (2**s - 1) would only be correct when k == 1.
        return mul_op(mo.mul * 2**mo.shl - 1, ms.val)
示例#30
0
def split_or(value):
    """Split an ``or``-combination into per-term rows.

    A value that is neither ``or`` nor ``mask_shl`` yields the single row
    ``(256, 0, value)``. A lone ``mask_shl`` is treated as a one-term ``or``.
    Each term is then normalised towards a ``mask_shl`` shape:

    * ``("bool", x)`` becomes an 8-bit mask of itself;
    * ``caller`` becomes a 160-bit mask of itself;
    * ``block.timestamp`` becomes a 64-bit mask of itself;
    * ``("mul", 1, x)`` is unwrapped to ``x``.

    NOTE(review): the rest of the per-row processing (filling ``ret_rows``)
    is outside this view.
    """
    orig_value = value

    if opcode(value) not in ("or", "mask_shl"):
        return [(256, 0, value)]

    if opcode(value) == "mask_shl":
        value = ("or", value)

    opcode_, *terms = value
    assert opcode_ == "or"

    ret_rows = []

    for row in terms:
        if m := match(row, ("bool", ":arg")):
            row = (
                "mask_shl",
                8,
                0,
                0,
                ("bool", m.arg),
            )  # does weird things if size == 1, in loops.activateSafeMode

        if row == "caller":
            # caller is wrapped as a 160-bit mask of itself.
            row = (
                "mask_shl",
                160,
                0,
                0,
                "caller",
            )

        if row == "block.timestamp":
            # block.timestamp must wrap itself. Wrapping "caller" here would
            # silently turn timestamp reads into caller reads.
            row = (
                "mask_shl",
                64,
                0,
                0,
                "block.timestamp",
            )

        if m := match(row, ("mul", 1, ":val")):
            row = m.val