Esempio n. 1
0
def to_while(trace, jd, path=None):
    """
    Walk ``trace`` line by line, collecting the lines that lead to ``jd``
    into ``path``.

    trace: list of trace lines (tuple expressions); consumed from the front.
    jd:    the jump-destination identifier being searched for in the
           branches of each ``if`` line.
    path:  lines accumulated so far; a fresh list is used when it is
           ``None`` or empty. A non-empty list passed by the caller is
           appended to in place.
    """
    path = path or []

    while True:
        if trace == []:
            # NOTE(review): no enclosing ``except`` handler is visible here; a bare
            # ``raise`` outside of one fails with RuntimeError - confirm this is
            # the intended way to signal "ran out of trace".
            raise
        # pop the first line off the remaining trace
        line = trace[0]
        trace = trace[1:]

        if m := match(line, ("if", ":cond", ":if_true", ":if_false")):
            cond, if_true, if_false = m.cond, m.if_true, m.if_false
            # a branch that reverts is folded into a "require" on the
            # condition under which the *other* branch is taken, and the
            # walk continues inside that other branch
            if is_revert(if_true):
                path.append(("require", is_zero(cond)))
                trace = if_false
                continue

            if is_revert(if_false):
                path.append(("require", cond))
                trace = if_true
                continue

            # everything get_jds reports within each branch
            jds_true = find_f_list(if_true, get_jds)
            jds_false = find_f_list(if_false, get_jds)

            # jd has to be found in exactly one of the two branches
            assert (jd in jds_true) != (jd in jds_false), (jds_true, jds_false)

            def add_path(line):
                # For a "goto" line, return the accumulated path with every
                # ("var", v_idx) replaced by the value the goto assigns to it,
                # followed by the goto itself; any other line is returned
                # on its own.
                if m := match(line, ("goto", Any, ":svs")):
                    path2 = path
                    for _, v_idx, v_val in m.svs:
                        path2 = replace(path2, ("var", v_idx), v_val)

                    return path2 + [line]
                else:
                    return [line]
Esempio n. 2
0
def postprocess_trace(line):
    """
        Collapse a branch on an array-length check into its ``if_true`` arm.

        let's find all the stuff like

         if (some_len % 32) == 0:
            return Array(some_len, some_stuff)
         else:
            mem[...] = leftover
            return Array(some_len, some_stuff, leftover)

        and replace it with just return Array(some_len, some_stuff)

        in theory this is incorrect, because the program might do one thing
        in one branch and something entirely different in the other,
        but this cleans up tremendous amounts of output, and no counterexample
        has been found yet.

        line: a single trace line (tuple expression).

        Returns ``if_true`` when ``line`` is an ``if`` on
        ``iszero(("storage", 5, 0, l))`` and both branches contain the same,
        non-zero number of ``("arr", ("storage", 256, 0, l), ...)`` expressions.
    """

    #    if line ~ ('setmem', ('range', :s, ('mask_shl', 251, 5, 0, ('add', 31, ('cd', ('add', 4, :param))))), ('data', ('call.data', ('add', 36, param), ('cd', ('add', 4, param))), ('mem', ...))):
    #        lin = ('setmem', ('range', s, ('cd', ('add', 4, param))), ('call.data', ('add', 36, param), ('cd', ('add', 4, param))))
    #        return [lin]

    if m := match(line,
                  ("if", ("iszero",
                          ("storage", 5, 0, ":l")), ":if_true", ":if_false")):
        len_loc, if_true, if_false = m.l, m.if_true, m.if_false

        def find_arr_l(exp):
            # arrays whose length is the full 256-bit read of the same location
            if match(exp, ("arr", ("storage", 256, 0, len_loc), ...)):
                return [exp]

        true_arr = find_f_list(if_true, find_arr_l)
        # the false branch must be inspected too; scanning if_true twice made
        # the length comparison below trivially true
        false_arr = find_f_list(if_false, find_arr_l)

        if len(true_arr) > 0 and len(true_arr) == len(false_arr):
            return if_true
Esempio n. 3
0
    def make_asts(self):
        """
            Build an AST for every function of the contract and store it on
            the function as ``func.ast``.

            we need to do ast creation from the contract, not function level,
            because some simplifications (type/field removal) require insight to all the functions,
            not just a single one
        """

        # first pass: one AST per function, built from that function's trace
        for func in self.functions:
            func.ast = self.make_ast(func.trace)

        def find_stor_masks(exp):
            # collector for find_f_list: keeps every expression whose opcode is "type"
            if opcode(exp) == "type":
                return [exp]
            else:
                return []

        # de-duplicated set of all "type" expressions found across every function's AST
        mlist = set(
            find_f_list([f.ast for f in self.functions], find_stor_masks))

        def cleanup(exp):
            # Rewrites applied to a single expression.

            # field 0 of a storage length -> the storage length itself
            if m := match(exp, ("field", 0, ("stor", ("length", ":idx")))):
                return ("stor", ("length", m.idx))

            # same as above, additionally wrapped in a 256 "type" node
            if m := match(exp, ("type", 256, ("field", 0,
                                              ("stor", ("length", ":idx"))))):
                return ("stor", ("length", m.idx))
Esempio n. 4
0
            def find_default(exp):
                # Returns the jump destination (as int) of an ``if`` whose
                # condition mentions ("cd", 0), whose false branch contains no
                # function calls and whose false branch starts with a "jd"
                # line; returns None for anything else.
                branch = match(exp, ("if", ":cond", ":if_true", ":if_false"))
                if not branch:
                    return None
                if str(("cd", 0)) not in str(branch.cond):
                    return None
                if find_f_list(branch.if_false, func_calls) != []:
                    return None

                head = match(branch.if_false[0], ("jd", ":jd"))
                if head:
                    return int(head.jd)
                return None
Esempio n. 5
0
    def make_params(self):
        """
            figures out parameter types from the decompiled function code.

            When get_func_params knows the signature for ``self.hash``, the
            parameters are taken from there. Otherwise it
            does so by looking at all 'cd'/calldata occurrences and figuring out
            how they are accessed - are they masked? are they used as pointers?

        """

        params = get_func_params(self.hash)
        if len(params) > 0:
            # known signature: map calldata offset -> (type, name)
            res = {}
            # offsets start at 4 and advance by 32 per parameter
            # (presumably 4-byte selector + 32-byte argument slots - TODO confirm)
            idx = 4
            for p in params:
                res[idx] = (p["type"], p["name"])
                idx += 32
        else:
            # good testing: solidstamp, auditContract
            # try to find all the references to parameters and guess their types

            def f(exp):
                # collector for find_f_list: keeps calldata reads, either bare
                # ("cd", x) or wrapped in a mask_shl
                if match(exp, ("mask_shl", Any, Any, Any,
                               ("cd", Any))) or match(exp, ("cd", Any)):
                    return [exp]
                return []

            occurences = find_f_list(self.trace, f)

            sizes = {}
            for o in occurences:
                # every item came through f above, so exactly one of the next
                # two patterns is expected to bind size and idx
                if m := match(o,
                              ("mask_shl", ":size", Any, Any, ("cd", ":idx"))):
                    size, idx = m.size, m.idx

                if m := match(o, ("cd", ":idx")):
                    # unmasked read: treated as a full 256-bit value
                    idx = m.idx
                    size = 256

                # reads at calldata offset 0 are not treated as parameters
                if idx == 0:
                    continue

                if m := match(idx, ("add", 4, ("cd", ":in_idx"))):
                    # this is a mark of 'cd' being used as a pointer
                    # (recorded with the sentinel size -1)
                    sizes[m.in_idx] = -1
                    continue
Esempio n. 6
0
        in theory this is incorrect, because the program might do one thing
        in one branch and something entirely different in the other,
        but this cleans up tremendous amounts of output, and no counterexample has been found yet.
    '''

#    if line ~ ('setmem', ('range', :s, ('mask_shl', 251, 5, 0, ('add', 31, ('cd', ('add', 4, :param))))), ('data', ('call.data', ('add', 36, param), ('cd', ('add', 4, param))), ('mem', ...))):
#        lin = ('setmem', ('range', s, ('cd', ('add', 4, param))), ('call.data', ('add', 36, param), ('cd', ('add', 4, param))))
#        return [lin]

    if line ~ ('if', ('iszero', ('storage', 5, 0, :l)), :if_true, :if_false):
        def find_arr_l(exp):
            if exp ~ ('arr', ('storage', 256, 0, l), ...):
                return [exp]

        true_arr = find_f_list(if_true, find_arr_l)
        false_arr = find_f_list(if_true, find_arr_l)

        if len(true_arr) > 0 and len(true_arr) == len(false_arr):
            return if_true


    if line ~ ('if', ('iszero', ('mask_shl', 5, 0, 0, :l)), :if_true, :if_false):
        def find_arr_l(exp):
            if exp ~ ('arr', l, ...):
                return [exp]
        true_arr = find_f_list(if_true, find_arr_l)
        false_arr = find_f_list(if_true, find_arr_l)

        if len(true_arr) > 0 and len(true_arr) == len(false_arr):
            return if_true
Esempio n. 7
0
        '''
            we need to do ast creation from the contract, not function level,
            because some simplifications (type/field removal) require insight to all the functions,
            not just a single one
        '''

        for func in self.functions:
            func.ast = self.make_ast(func.trace)

        def find_stor_masks(exp):
            if exp ~ ('type', ...):
                return [exp]
            else:
                return []

        mlist = set(find_f_list([f.ast for f in self.functions], find_stor_masks))

        def cleanup(exp):

            if exp ~ ('field', 0, ('stor', ('length', :idx))):
                return ('stor', ('length', idx))

            if exp ~ ('type', 256, ('field', 0, ('stor', ('length', :idx)))):
                return ('stor', ('length', idx))

            if exp ~ ('type', 256, ('stor', ('length', :idx))):
                return ('stor', ('length', idx))

            if exp ~ ('type', :e_type, ('field', :e_field, ('stor', ('name', :e_name, :loc)))):
                for m in mlist:
                    if get_name(m) == e_name:
Esempio n. 8
0
        try:
            # decompiles the code, starting from location 0
            # and running VM in a special mode that returns 'funccall'
            # in places where it looks like there is a func call

            trace = vm.run(0)

            def func_calls(exp):
                # Collector for find_f_list: a "funccall" node contributes its
                # (hash, target, stack) triple, anything else contributes nothing.
                call = match(exp, ("funccall", ":fx_hash", ":target", ":stack"))
                if not call:
                    return []
                return [(call.fx_hash, call.target, call.stack)]

            func_list = find_f_list(trace, func_calls)

            for fx_hash, target, stack in func_list:
                self.add_func(target=target, hash=fx_hash, stack=stack)

            # find default

            def find_default(exp):

                if (m := match(
                        exp,
                    ("if", ":cond", ":if_true", ":if_false"))) and str(
                        ("cd", 0)) in str(m.cond):
                    if find_f_list(m.if_false, func_calls) == []:
                        fi = m.if_false[0]
                        if m2 := match(fi, ("jd", ":jd")):
Esempio n. 9
0
            raise
        line = trace[0]
        trace = trace[1:]

        if (line ~ ('if', :cond, :if_true, :if_false)):
            if is_revert(if_true):
                path.append(('require', is_zero(cond)))
                trace = if_false
                continue

            if is_revert(if_false):
                path.append(('require', cond))
                trace = if_true
                continue

            jds_true = find_f_list(if_true, get_jds)
            jds_false = find_f_list(if_false, get_jds)

            assert (jd in jds_true) != (jd in jds_false), (jds_true, jds_false)

            def add_path(line):
                if line ~ ('goto', _, :svs):
                    path2 = path
                    for _, v_idx, v_val in svs:
                        path2 = replace(path2, ('var', v_idx), v_val)

                    return path2 + [line]
                else:
                    return [line]

            if jd in jds_true:
Esempio n. 10
0
    def analyse(self):
        assert len(self.trace) > 0

        def find_returns(exp):
            if opcode(exp) == 'return':
                return [exp]
            else:
                return []

        exp_text = []

        self.returns = find_f_list(self.trace, find_returns)

        exp_text.append(('possible return values', prettify(self.returns)))

        first = self.trace[0]

        if opcode(first) == 'if' and simplify_bool(first[1]) == 'callvalue'   \
                and (first[2][0] == ('revert', 0) or opcode(first[2][0]) == 'invalid'):
            self.trace = self.trace[0][3]
            self.payable = False
        elif opcode(first) == 'if' and simplify_bool(first[1]) == ('iszero', 'callvalue')   \
                and (first[3][0] == ('revert', 0) or opcode(first[3][0]) == 'invalid'):
            self.trace = self.trace[0][2]
            self.payable = False
        else:
            self.payable = True

        exp_text.append(('payable', self.payable))

        self.read_only = True
        for op in ['store', 'selfdestruct', 'call', 'delegatecall', 'codecall', 'create']:
            if f"'{op}'" in str(self.trace):
                self.read_only = False

        exp_text.append(('read_only', self.read_only))


        '''
            const func detection
        '''

        self.const = self.read_only
        for exp in ['storage', 'calldata', 'calldataload', 'store', 'cd']:
            if exp in str(self.trace) or len(self.returns)!=1:
                self.const = False

        if self.const:
            self.const = self.returns[0]
            if len(self.const) == 3 and opcode(self.const[2]) == 'data':
                self.const = self.const[2]
            if len(self.const) == 3 and opcode(self.const[2]) == 'mask_shl':
                self.const = self.const[2]
            if len(self.const) == 3 and type(self.const[2]) == int:
                self.const = self.const[2]
        else:
            self.const = None

        if self.const:
            exp_text.append(('const', self.const))

        '''
            getter detection
        '''

        self.getter = None
        self.simplify_string_getter_from_storage()
        if self.const is None and \
           self.read_only and \
           len(self.returns) == 1:
                ret = self.returns[0][1]
                if ret ~ ('bool', ('storage', _, _, :loc)):
 def find_default(exp):
     if exp ~ ('if', :cond, :if_true, :if_false) and str(('cd', 0)) in str(cond):
         if find_f_list(if_false, func_calls) == []:
             fi = if_false[0]
             if fi ~ ('jd', :jd):
Esempio n. 12
0
    def analyse(self):
        """
            Derive function traits from ``self.trace`` and store them on self:

            - ``returns``:   every "return" expression found in the trace
            - ``payable``:   False when the trace starts with a callvalue guard
                             (the guard is then stripped from ``self.trace``)
            - ``read_only``: False when the trace mentions a state-changing opcode
            - ``const``:     the returned constant, or None
            - ``getter``:    the storage expression being returned, or None

            Each trait is also appended to the local ``exp_text`` list as a
            (label, value) pair.
        """
        assert len(self.trace) > 0

        def find_returns(exp):
            # collector for find_f_list: keeps every "return" expression
            if opcode(exp) == "return":
                return [exp]
            else:
                return []

        exp_text = []

        self.returns = find_f_list(self.trace, find_returns)

        exp_text.append(("possible return values", prettify(self.returns)))

        first = self.trace[0]

        # A function is considered non-payable when its first line is an ``if``
        # on callvalue whose "value was sent" branch begins with revert/invalid.
        # The guard is then dropped and the trace becomes the other branch.
        if (opcode(first) == "if" and simplify_bool(first[1]) == "callvalue"
                and (first[2][0] == ("revert", 0)
                     or opcode(first[2][0]) == "invalid")):
            self.trace = self.trace[0][3]
            self.payable = False
        elif (opcode(first) == "if"
              and simplify_bool(first[1]) == ("iszero", "callvalue")
              and (first[3][0] == ("revert", 0)
                   or opcode(first[3][0]) == "invalid")):
            # same guard with the condition negated, so the branches are swapped
            self.trace = self.trace[0][2]
            self.payable = False
        else:
            self.payable = True

        exp_text.append(("payable", self.payable))

        # read-only unless the textual form of the trace contains one of these
        # opcodes as a quoted string (e.g. "'store'")
        self.read_only = True
        for op in [
                "store",
                "selfdestruct",
                "call",
                "delegatecall",
                "codecall",
                "create",
        ]:
            if f"'{op}'" in str(self.trace):
                self.read_only = False

        exp_text.append(("read_only", self.read_only))
        """
            const func detection
        """

        # a const candidate is read-only, has exactly one return, and its trace
        # text contains none of the keywords below
        # NOTE(review): unlike the read_only check, this is an unquoted substring
        # test, so a keyword also matches inside longer words - confirm intended.
        self.const = self.read_only
        for exp in ["storage", "calldata", "calldataload", "store", "cd"]:
            if exp in str(self.trace) or len(self.returns) != 1:
                self.const = False

        if self.const:
            # start from the single return expression and unwrap it step by
            # step; each check looks at the result of the previous one
            self.const = self.returns[0]
            if len(self.const) == 3 and opcode(self.const[2]) == "data":
                self.const = self.const[2]
            if len(self.const) == 3 and opcode(self.const[2]) == "mask_shl":
                self.const = self.const[2]
            if len(self.const) == 3 and type(self.const[2]) == int:
                self.const = self.const[2]
        else:
            self.const = None

        # NOTE(review): a constant of 0 is falsy, so it is not reported here,
        # yet it is also not None, so getter detection below is skipped too.
        if self.const:
            exp_text.append(("const", self.const))
        """
            getter detection
        """

        self.getter = None
        self.simplify_string_getter_from_storage()
        # only read-only, non-const functions with a single return are considered
        if self.const is None and self.read_only and len(self.returns) == 1:
            ret = self.returns[0][1]
            if match(ret, ("bool", ("storage", Any, Any, ":loc"))):
                self.getter = (
                    ret  # we have to be careful when using this for naming purposes,
                )
                # because sometimes the storage can refer to array length

            elif opcode(ret) == "mask_shl" and opcode(ret[4]) == "storage":
                # masked storage read: the getter is the storage expression inside
                self.getter = ret[4]
            elif opcode(ret) == "storage":
                self.getter = ret
            elif opcode(ret) == "data":
                terms = ret[1:]
                # for structs, we check if all the parts of the struct are storage from the same
                # location. if so, we return the location number

                t0 = terms[
                    0]  # 0xFAFfea71A6da719D6CAfCF7F52eA04Eb643F6De2 - documents
                if m := match(t0, ("storage", 256, 0, ":loc")):
                    loc = m.loc
                    # every remaining term must be a full-word read at (loc + something)
                    for e in terms[1:]:
                        if not match(e,
                                     ("storage", 256, 0, ("add", Any, loc))):
                            break
                    else:
                        # loop finished without break: all terms share the location
                        self.getter = t0

                # kitties getKitten - with more cases this and the above could be uniformed
                if self.getter is None:
                    prev_loc = -1
                    for e in terms:

                        def l2(x):
                            # returns the small (< 1000) integer found inside a
                            # sha3 expression, in either of two shapes, else None
                            if m := match(x, ("sha3", ("data", Any, ":l"))):
                                if type(m.l) == int and m.l < 1000:
                                    return m.l
                            if (opcode(x) == "sha3" and type(x[1]) == int
                                    and x[1] < 1000):
                                return x[1]
                            return None

                        loc = find_f(e, l2)
                        # stop as soon as a term has no location or disagrees
                        # with the previous term's location
                        # NOTE(review): ``not loc`` also rejects a location of 0 - confirm intended.
                        if not loc or (prev_loc != -1 and prev_loc != loc):
                            break
                        prev_loc = loc

                    else:
                        # every term resolved to the same location
                        self.getter = ("struct", ("loc", loc))