Exemplo n.º 1
0
    def read(self, addr, size):
        """Return ``size`` units of memory starting at offset ``addr``.

        ``self._mem`` maps a start address to a stored value whose length is
        taken as ``value.size() // 8`` (presumably bits -> bytes; confirm
        against the value type).  The first stored chunk that intersects the
        requested range decides how the result is assembled; parts of the
        range not covered by that chunk are fetched through recursive calls.
        When no chunk intersects, a value from ``self._default`` is stored at
        ``addr`` and returned.

        Raises:
            utils.CodeError: if ``size < 1`` or ``addr + size >= MEMORY_SIZE``.
        """
        if size < 1:
            raise utils.CodeError("size < 1 in Memory.read")
        if addr + size >= MEMORY_SIZE:
            raise utils.CodeError("overflow in Memory.read")
        logger.debug("%s.read(%i, %i)" % (self.__class__.__name__, addr, size))

        for chunk_addr, chunk in self._mem.items():
            chunk_len = chunk.size() // 8
            # Chunk bounds expressed relative to the start of the request.
            rel_start = chunk_addr - addr
            rel_end = chunk_addr + chunk_len - addr

            # The chunk covers the whole request (or matches it exactly).
            if rel_start <= 0 and rel_end >= size:
                return _slice(chunk, -rel_start, -rel_start + size)

            # The chunk lies strictly inside the request: split the request
            # at the chunk start and read both halves separately.
            if rel_start > 0 and rel_end < size:
                head = self.read(addr, rel_start)
                return head.concat(self.read(addr + rel_start, size - rel_start))

            # Only the end of the chunk reaches into the request.
            if 0 < rel_end < size:
                head = _slice(chunk, chunk_len - rel_end, None)
                return head.concat(self.read(addr + rel_end, size - rel_end))

            # Only the start of the chunk reaches into the request.
            if 0 < rel_start < size:
                head = self.read(addr, rel_start)
                return head.concat(_slice(chunk, 0, size - rel_start))

        # Nothing stored intersects the range: materialise a default value.
        assert addr not in self._mem
        self._mem[addr] = self._default(addr, size)
        return self._mem[addr]
Exemplo n.º 2
0
 def copy_from(self, other, start_self, start_other, size):
     """Copy ``size`` units from ``other`` into this memory.

     The data is taken from ``other`` starting at ``start_other`` (wrapped
     in a ``CalldataMemoryView``) and written at ``start_self``.  A zero
     ``size`` is a no-op.  The bounds check only applies when
     ``MEMORY_SIZE`` is truthy.

     Raises:
         utils.CodeError: if either range would reach ``MEMORY_SIZE``.
     """
     assert size >= 0
     if MEMORY_SIZE:
         if start_self + size >= MEMORY_SIZE or start_other + size >= MEMORY_SIZE:
             raise utils.CodeError(
                 "Memory.copy_from: memory would exceed MEMORY_SIZE"
             )
     if size != 0:
         self.write(start_self, size, CalldataMemoryView(other, start_other, size))
Exemplo n.º 3
0
    def write(self, addr, size, value):
        """Store ``value`` as the chunk covering ``[addr, addr + size)``.

        ``self._mem`` maps start addresses to stored values whose length is
        taken as ``value.size() // 8``.  Every stored chunk that intersects
        the written range is trimmed, split or removed so that, afterwards,
        the new value is the only chunk covering the range.

        Args:
            addr: start offset of the write.
            size: length of the write; must equal ``value.size() // 8``.
            value: the value to store.

        Raises:
            utils.CodeError: if ``size < 1`` or ``addr + size >= MEMORY_SIZE``.
            utils.InterpreterError: if ``value`` does not have length ``size``.
        """
        if size < 1:
            raise utils.CodeError("size < 1 in Memory.write")
        if addr + size >= MEMORY_SIZE:
            raise utils.CodeError("overflow in Memory.write")
        if value.size() // 8 != size:
            raise utils.InterpreterError("BVV size doesn't match size in Memory.write")

        logger.debug(
            "%s.write(%i, %i, %r)" % (self.__class__.__name__, addr, size, value)
        )

        # Iterate over a snapshot: the loop body mutates self._mem.
        for iaddr, ivalue in list(self._mem.items()):
            isize = ivalue.size() // 8
            # Bounds of the existing chunk relative to the start of the write.
            raddr = iaddr - addr
            rend = iaddr + isize - addr
            # equal: the final assignment below simply replaces the chunk
            if raddr == 0 and rend == size:
                break
            # completely overlaps: keep what sticks out on either side
            elif raddr <= 0 and rend >= size:
                if raddr < 0:
                    self._mem[iaddr] = _slice(ivalue, 0, -raddr)
                # Only store a tail when there is one.  Without this guard a
                # chunk ending exactly at the end of the write would yield an
                # empty tail that replaces whatever is stored at addr + size.
                if rend > size:
                    self._mem[addr + size] = _slice(ivalue, -raddr + size, None)
            # completely inside (not strictly): fully shadowed by the write
            elif raddr >= 0 and rend <= size:
                del self._mem[iaddr]
            # end inside: keep only the part before the write
            elif 0 < rend < size:
                self._mem[iaddr] = _slice(ivalue, 0, -raddr)
            # start inside: keep only the part after the write
            elif 0 < raddr < size:
                del self._mem[iaddr]
                self._mem[addr + size] = _slice(ivalue, size - raddr, None)

        self._mem[addr] = value
Exemplo n.º 4
0
    def write(self, addr, size, value):
        """Store ``value`` as the chunk covering ``[addr, addr + size)``.

        Existing chunks in ``self._mem`` (start address -> value, length
        taken as ``value.size() // 8``) that intersect the written range are
        trimmed, split or dropped before the new value is recorded.  A write
        of ``size == 0`` is logged and otherwise ignored.

        Raises:
            AssertionError: if ``size`` is negative or does not match the
                length of ``value``.
            utils.CodeError: if ``MEMORY_SIZE`` is set and the write would
                reach it.
        """
        assert size >= 0
        value_len = value.size() // 8
        assert value_len == size, "BVV size doesn't match size in Memory.write"
        if MEMORY_SIZE and addr + size >= MEMORY_SIZE:
            raise utils.CodeError(
                "Memory.write: memory would exceed MEMORY_SIZE")

        logger.debug("%s.write(%i, %i, %r)" %
                     (self.__class__.__name__, addr, size, value))

        if size == 0:
            return

        write_end = addr + size

        # Work on a snapshot because the mapping is modified while scanning.
        for start, chunk in list(self._mem.items()):
            chunk_len = chunk.size() // 8
            # Chunk bounds relative to the start of the write.
            rel_start = start - addr
            rel_end = rel_start + chunk_len

            # Same extent: the final assignment replaces it, stop scanning.
            if rel_start == 0 and rel_end == size:
                break

            # Chunk spans the whole write: keep the parts sticking out.
            if rel_start <= 0 and rel_end >= size:
                if rel_start < 0:
                    self._mem[start] = _slice(chunk, 0, -rel_start)
                if rel_end > size:
                    self._mem[write_end] = _slice(chunk, size - rel_start,
                                                  None)
                continue

            # Chunk entirely shadowed by the write.
            if rel_start >= 0 and rel_end <= size:
                del self._mem[start]
                continue

            # Chunk's end reaches into the write: keep its leading part.
            if 0 < rel_end < size:
                self._mem[start] = _slice(chunk, 0, -rel_start)
                continue

            # Chunk's start lies in the write: keep its trailing part.
            if 0 < rel_start < size:
                del self._mem[start]
                self._mem[write_end] = _slice(chunk, size - rel_start, None)

        self._mem[addr] = value
Exemplo n.º 5
0
 def copy_from(self, other, start_self, start_other, size):
     """Copy ``size`` units from ``other`` into this memory.

     The source range starts at ``start_other`` in ``other`` and is wrapped
     in a ``CalldataMemoryView``; it is written at ``start_self``.

     Raises:
         utils.CodeError: if ``size < 1`` or either range would reach
             ``MEMORY_SIZE``.
     """
     if size < 1:
         raise utils.CodeError("size < 1 in Memory.copy_from")

     out_of_bounds = (
         start_self + size >= MEMORY_SIZE or start_other + size >= MEMORY_SIZE
     )
     if out_of_bounds:
         raise utils.CodeError("overflow in Memory.copy_from")

     self.write(start_self, size, CalldataMemoryView(other, start_other, size))
Exemplo n.º 6
0
 def stack_pop(self):
     """Remove and return the top of ``self.stack``.

     Raises:
         utils.CodeError: if the stack is empty.
     """
     if self.stack:
         return self.stack.pop()
     raise utils.CodeError("Stack underflow")
Exemplo n.º 7
0
 def stack_push(self, x):
     """Append ``x`` to ``self.stack``.

     Raises:
         utils.CodeError: if the stack already holds 1024 or more items.
     """
     if len(self.stack) < 1024:
         self.stack.append(x)
         return
     raise utils.CodeError("Stack overflow")
Exemplo n.º 8
0
    def exec_branch(self, state):  # pylint:disable=invalid-name
        """Execute forward from a state, queuing new states if needed.

        Opcodes are interpreted one at a time starting at ``state.pc``.
        Straight-line opcodes update ``state`` in place; opcodes that fork or
        redirect execution hand the resulting state(s) to ``self.add_branch``
        (or ``self.add_for_fuzzing`` when a value that must be concrete has
        several solutions) and stop interpreting this state.

        Returns:
            ``True`` when execution ends by running past the end of the code
            or on STOP, RETURN or SELFDESTRUCT.  ``False`` when the state was
            queued for later execution, or on REVERT.  SLOAD and SSTORE
            currently return ``None`` after queuing their successor states.

        Raises:
            utils.CodeError: on PUSH data being executed, the INVALID opcode
                or an invalid jump target.
            utils.InterpreterError: on an unknown opcode or an unsupported
                BALANCE query.
            MultipleSolutionsError: when an operand that must be concrete has
                more than one solution and no fallback exists for the opcode.
        """
        logger.debug("Constraints: %s", state.solver.constraints)

        def solution(variable):
            """Returns the solution. There must be one or we fail."""
            # Asking for two solutions is enough to detect ambiguity.
            solutions = state.solver.eval(variable, 2)
            if len(solutions) > 1:
                raise MultipleSolutionsError(
                    "Multiple solutions for %s (%#x)" %
                    (variable, self.code[state.pc]))
            only = solutions[0]
            return only if isinstance(only, numbers.Number) else only.value

        self.code.pc = state.pc

        while True:
            if state.pc >= len(self.code):
                return True

            op = self.code[state.pc]
            self.code.pc += 1
            self.coverage[state.pc] += 1

            logger.debug("NEW STEP")
            logger.debug("Memory: %s", state.memory)
            logger.debug("Stack: %s", state.stack)
            logger.debug("PC: %i, %#x", state.pc, op)

            assert self.code.pc == state.pc + 1
            assert isinstance(op, numbers.Number)
            assert all(
                isinstance(i, claripy.ast.base.BV) for i in
                state.stack), "The stack musty only contains claripy BV's"

            # Trivial operations first
            if not self.code.is_valid_opcode(state.pc):
                raise utils.CodeError("Trying to execute PUSH data")
            elif op == 254:  # INVALID opcode
                raise utils.CodeError("designed INVALID opcode")
            elif op == opcode_values.JUMPDEST:
                pass

            # ---- Arithmetic ----
            elif op == opcode_values.ADD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 + s1)
            elif op == opcode_values.SUB:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 - s1)
            elif op == opcode_values.MUL:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 * s1)
            elif op == opcode_values.DIV:
                # We need to use claripy.LShR instead of a division if possible,
                # because the solver is bad dealing with divisions, better
                # with shifts. And we need shifts to handle the solidity ABI
                # for function selection.
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)  # pylint:disable=invalid-name
                except MultipleSolutionsError:
                    # Symbolic divisor: a division by zero yields 0.
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 / s1))
                else:
                    if s1 == 0:
                        state.stack_push(BVV_0)
                    elif s1 == 1:
                        state.stack_push(s0)
                    elif s1 & (s1 - 1) == 0:
                        # Power of two: replace the division by a shift.
                        exp = int(math.log(s1, 2))
                        state.stack_push(s0.LShR(exp))
                    else:
                        state.stack_push(s0 / s1)
            elif op == opcode_values.SDIV:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SDiv(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SDiv(s1))
            elif op == opcode_values.MOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 % s1))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0 % s1)
            elif op == opcode_values.SMOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SMod(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SMod(s1))
            elif op == opcode_values.ADDMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                try:
                    s2 = solution(s2)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s2 == 0, BVV_0,
                                                (s0 + s1) % s2))
                else:
                    state.stack_push(BVV_0 if s2 == 0 else (s0 + s1) % s2)
            elif op == opcode_values.MULMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                try:
                    s2 = solution(s2)
                except MultipleSolutionsError:
                    state.stack_push(claripy.If(s2 == 0, BVV_0,
                                                (s0 * s1) % s2))
                else:
                    state.stack_push(BVV_0 if s2 == 0 else (s0 * s1) % s2)

            # ---- Shifts ----
            elif op == opcode_values.SHL:
                shift, value = state.stack_pop(), state.stack_pop()
                state.stack_push(value << shift)
            elif op == opcode_values.SHR:
                shift, value = state.stack_pop(), state.stack_pop()
                state.stack_push(value.LShR(shift))
            elif op == opcode_values.SAR:
                shift, value = state.stack_pop(), state.stack_pop()
                # SAR is an arithmetic (sign-filling) shift, not a rotation:
                # ``>>`` on a claripy bitvector is the arithmetic right shift.
                state.stack_push(value >> shift)
            elif op == opcode_values.EXP:
                base, exponent = state.stack_pop(), state.stack_pop()
                base_sol = solution(base)
                if base_sol == 2:
                    # 2 ** x can stay symbolic as a shift.
                    state.stack_push(1 << exponent)
                else:
                    try:
                        exponent_sol = solution(exponent)
                    except MultipleSolutionsError:
                        state.stack_push(exponent)  # restore stack
                        state.stack_push(base)
                        self.add_for_fuzzing(state, exponent,
                                             EXP_EXPONENT_FUZZ)
                        return False
                    else:
                        state.stack_push(
                            claripy.BVV(base_sol**exponent_sol, 256))

            # ---- Comparisons ----
            elif op == opcode_values.LT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.ULT(s0, s1)))
            elif op == opcode_values.GT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.UGT(s0, s1)))
            elif op == opcode_values.SLT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.SLT(s0, s1)))
            elif op == opcode_values.SGT:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(bool_to_bv(claripy.SGT(s0, s1)))
            elif op == opcode_values.SIGNEXTEND:
                # TODO: Use Claripy's SignExt that should do exactly that.
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                # s0 is the number of bits. s1 the number we want to extend.
                s0 = solution(s0)
                if s0 <= 31:
                    sign_bit = 1 << (s0 * 8 + 7)
                    state.stack_push(
                        claripy.If(
                            s1 & sign_bit == 0,
                            s1 & (sign_bit - 1),
                            s1 | ((1 << 256) - sign_bit),
                        ))
                else:
                    state.stack_push(s1)
            elif op == opcode_values.EQ:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(bool_to_bv(s0 == s1))
            elif op == opcode_values.ISZERO:
                state.stack_push(bool_to_bv(state.stack_pop() == BVV_0))

            # ---- Bitwise logic ----
            elif op == opcode_values.AND:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(s0 & s1)
            elif op == opcode_values.OR:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(s0 | s1)
            elif op == opcode_values.XOR:
                s0, s1 = state.stack_pop(), state.stack_pop()
                state.stack_push(s0 ^ s1)
            elif op == opcode_values.NOT:
                state.stack_push(~state.stack_pop())
            elif op == opcode_values.BYTE:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(
                    s1.LShR(claripy.If(s0 > 31, 32, 31 - s0) * 8) & 0xFF)

            # ---- Environment and block information ----
            elif op == opcode_values.PC:
                state.stack_push(bvv(state.pc))
            elif op == opcode_values.GAS:
                state.stack_push(state.env.gas)
            elif op == opcode_values.ADDRESS:
                state.stack_push(state.env.address)
            elif op == opcode_values.BALANCE:
                addr = solution(state.stack_pop())
                if addr != solution(state.env.address):
                    raise utils.InterpreterError(
                        state,
                        "Can only query balance of the current contract for now"
                    )
                state.stack_push(state.env.balance)
            elif op == opcode_values.ORIGIN:
                state.stack_push(state.env.origin)
            elif op == opcode_values.CALLER:
                state.stack_push(state.env.caller)
            elif op == opcode_values.CALLVALUE:
                state.stack_push(state.env.value)
            elif op == opcode_values.BLOCKHASH:
                block_num = state.stack_pop()
                # One fresh symbol per distinct block number expression.
                if block_num not in state.env.block_hashes:
                    state.env.block_hashes[block_num] = claripy.BVS(
                        "blockhash[%s]" % block_num, 256)
                state.stack_push(state.env.block_hashes[block_num])
            elif op == opcode_values.TIMESTAMP:
                state.stack_push(state.env.block_timestamp)
            elif op == opcode_values.NUMBER:
                state.stack_push(state.env.block_number)
            elif op == opcode_values.COINBASE:
                state.stack_push(state.env.coinbase)
            elif op == opcode_values.DIFFICULTY:
                state.stack_push(state.env.difficulty)

            # ---- Stack and control flow ----
            elif op == opcode_values.POP:
                state.stack_pop()
            elif op == opcode_values.JUMP:
                addr = solution(state.stack_pop())
                if addr >= len(self.code
                               ) or self.code[addr] != opcode_values.JUMPDEST:
                    raise utils.CodeError("Invalid jump (%i)" % addr)
                state.pc = addr
                self.add_branch(state)
                return False
            elif op == opcode_values.JUMPI:
                addr, condition = solution(
                    state.stack_pop()), state.stack_pop()
                # Fork: one state falls through, the other takes the jump.
                state_false = state.copy()
                state.solver.add(condition != BVV_0)
                state_false.solver.add(condition == BVV_0)
                state_false.pc += 1
                self.add_branch(state_false)
                state.pc = addr
                if (state.pc >= len(self.code)
                        or self.code[state.pc] != opcode_values.JUMPDEST):
                    raise utils.CodeError("Invalid jump (%i)" % (state.pc - 1))
                self.add_branch(state)
                return False
            elif opcode_values.PUSH1 <= op <= opcode_values.PUSH32:
                pushnum = op - opcode_values.PUSH1 + 1
                raw_value = self.code.read(pushnum)
                # Skip the immediate bytes; the opcode itself is skipped by
                # the increment at the bottom of the loop.
                state.pc += pushnum
                state.stack_push(
                    bvv(int.from_bytes(raw_value, byteorder="big")))
            elif opcode_values.DUP1 <= op <= opcode_values.DUP16:
                depth = op - opcode_values.DUP1 + 1
                state.stack_push(state.stack[-depth])
            elif opcode_values.SWAP1 <= op <= opcode_values.SWAP16:
                depth = op - opcode_values.SWAP1 + 1
                temp = state.stack[-depth - 1]
                state.stack[-depth - 1] = state.stack[-1]
                state.stack[-1] = temp
            elif opcode_values.LOG0 <= op <= opcode_values.LOG4:
                # The operands are consumed but otherwise ignored here.
                depth = op - opcode_values.LOG0
                mstart, msz = (state.stack_pop(), state.stack_pop())
                topics = [state.stack_pop() for x in range(depth)]
            elif op == opcode_values.SHA3:
                start, length = solution(state.stack_pop()), solution(
                    state.stack_pop())
                memory = state.memory.read(start, length)
                state.stack_push(Sha3(memory))
            elif op == opcode_values.STOP:
                return True
            elif op == opcode_values.RETURN:
                return True

            # ---- Calldata and code access ----
            elif op == opcode_values.CALLDATALOAD:
                index = state.stack_pop()
                try:
                    index_sol = solution(index)
                except MultipleSolutionsError:
                    state.stack_push(index)  # restore the stack
                    self.add_for_fuzzing(state, index, CALLDATALOAD_INDEX_FUZZ)
                    return False
                state.stack_push(state.env.calldata.read(index_sol, 32))
            elif op == opcode_values.CALLDATASIZE:
                state.stack_push(state.env.calldata_size)
            elif op == opcode_values.CALLDATACOPY:
                # Snapshot taken before popping so fuzzing can replay the op.
                old_state = state.copy()
                mstart, dstart, size = (
                    state.stack_pop(),
                    state.stack_pop(),
                    state.stack_pop(),
                )
                mstart, dstart = solution(mstart), solution(dstart)
                try:
                    size = solution(size)
                except MultipleSolutionsError:
                    self.add_for_fuzzing(old_state, size,
                                         CALLDATACOPY_SIZE_FUZZ)
                    return False
                state.memory.copy_from(state.env.calldata, mstart, dstart,
                                       size)
            elif op == opcode_values.CODESIZE:
                state.stack_push(bvv(len(self.code)))
            elif op == opcode_values.EXTCODESIZE:
                addr = state.stack_pop()
                if (addr == state.env.address).is_true():
                    state.stack_push(bvv(len(self.code)))
                else:
                    # TODO: Improve that... It's clearly not constraining enough.
                    state.stack_push(claripy.BVS("EXTCODESIZE[%s]" % addr,
                                                 256))

            elif op == opcode_values.EXTCODECOPY:
                old_state = state.copy()
                addr = state.stack_pop()
                mem_start = solution(state.stack_pop())
                code_start = solution(state.stack_pop())

                size = state.stack_pop()
                try:
                    size = solution(size)
                except MultipleSolutionsError:
                    # TODO: Fuzz.
                    # self.add_for_fuzzing(old_state, size, [])
                    # return False
                    raise
                state.memory.write(
                    mem_start,
                    size,
                    claripy.BVS("EXTCODE[%s from %s]" % (addr, code_start),
                                size * 8),
                )

            elif op == opcode_values.CODECOPY:
                mem_start, code_start, size = [
                    solution(state.stack_pop()) for _ in range(3)
                ]
                # Copy byte by byte; reads past the end of the code yield 0.
                for i in range(size):
                    if code_start + i < len(state.env.code):
                        state.memory.write(
                            mem_start + i,
                            1,
                            claripy.BVV(state.env.code[code_start + i], 8),
                        )
                    else:
                        state.memory.write(mem_start + i, 1, claripy.BVV(0, 8))

            # ---- Memory ----
            elif op == opcode_values.MLOAD:
                index = solution(state.stack_pop())
                state.stack_push(state.memory.read(index, 32))
            elif op == opcode_values.MSTORE:
                index, value = solution(state.stack_pop()), state.stack_pop()
                state.memory.write(index, 32, value)
            elif op == opcode_values.MSTORE8:
                index, value = solution(state.stack_pop()), state.stack_pop()
                state.memory.write(index, 1, value[7:0])
            elif op == opcode_values.MSIZE:
                state.stack_push(bvv(state.memory.size()))

            # ---- Storage ----
            elif op == opcode_values.SLOAD:
                state.pc += 1
                key = state.stack_pop()
                # Fork once per previously written key that may alias ``key``.
                for w_key, w_value in state.storage_written.items():
                    read_written = [w_key == key]
                    if state.solver.satisfiable(
                            extra_constraints=read_written):
                        new_state = state.copy()
                        new_state.solver.add(read_written)
                        new_state.stack_push(w_value)
                        self.add_branch(new_state)
                    state.solver.add(w_key != key)
                # Remaining case: the key was never written in this path.
                if state.solver.satisfiable():
                    assert key not in state.storage_written
                    if key not in state.storage_read:
                        state.storage_read[key] = claripy.BVS(
                            "storage[%s]" % key, 256)
                    state.stack_push(state.storage_read[key])
                    self.add_branch(state)
                # NOTE(review): returns None whereas the other state-queuing
                # opcodes return False -- confirm callers only test
                # truthiness before unifying.
                return

            elif op == opcode_values.SSTORE:
                state.pc += 1
                key = state.stack_pop()
                value = state.stack_pop()
                # Fork once per previously written key that may alias ``key``.
                for w_key, w_value in state.storage_written.items():
                    read_written = [w_key == key]
                    if state.solver.satisfiable(
                            extra_constraints=read_written):
                        new_state = state.copy()
                        new_state.solver.add(read_written)
                        new_state.storage_written[w_key] = value
                        self.add_branch(new_state)
                        state.solver.add(w_key != key)
                if state.solver.satisfiable():
                    assert key not in state.storage_written
                    state.storage_written[key] = value
                    self.add_branch(state)
                # NOTE(review): returns None, see the SLOAD note above.
                return

            # ---- Calls ----
            elif op == opcode_values.CALL:
                state.pc += 1

                # pylint:disable=unused-variable
                gas, to_, value, meminstart, meminsz, memoutstart, memoutsz = (
                    state.stack_pop() for _ in range(7))

                # First possibility: the call fails
                # (always possible with a call stack big enough)
                state_fail = state.copy()
                state_fail.stack_push(BVV_0)
                self.add_branch(state_fail)

                # Second possibility: success.
                state.calls.append((memoutsz, memoutstart, meminsz, meminstart,
                                    value, to_, gas))

                memoutsz = solution(memoutsz)
                if memoutsz != 0:
                    # If we expect some output, let's constraint the call to
                    # be to a contract that we do control. Otherwise it could
                    # return anything...
                    state.solver.add(to_[159:0] == utils.DEFAULT_CALLER[159:0])

                    memoutstart = solution(memoutstart)
                    state.memory.write(
                        memoutstart,
                        memoutsz,
                        claripy.BVS("CALL_RETURN[%s]" % to_, memoutsz * 8),
                    )

                state.stack_push(BVV_1)
                self.add_branch(state)
                return False

            elif op == opcode_values.DELEGATECALL:
                state.pc += 1

                # pylint:disable=unused-variable
                gas, to_, meminstart, meminsz, memoutstart, memoutsz = (
                    state.stack_pop() for _ in range(6))

                # First possibility: the call fails
                # (always possible with a call stack big enough)
                state_fail = state.copy()
                state_fail.stack_push(BVV_0)
                self.add_branch(state_fail)

                # If the call is to a specific contract we don't control,
                # don't assume it could return anything, or even be successful.
                # So we say we need to be able to call an arbitrary contract.
                state.solver.add(to_[159:0] == utils.DEFAULT_CALLER[159:0])

                # Second possibility: success.
                state.calls.append(
                    (memoutsz, memoutstart, meminsz, meminstart, to_, gas))

                memoutsz = solution(memoutsz)
                if memoutsz != 0:
                    memoutstart = solution(memoutstart)
                    state.memory.write(
                        memoutstart,
                        memoutsz,
                        claripy.BVS("DELEGATECALL_RETURN[%s]" % to_,
                                    memoutsz * 8),
                    )

                state.stack_push(BVV_1)
                self.add_branch(state)
                return False

            elif op == opcode_values.RETURNDATASIZE:
                state.stack_push(claripy.BVS("RETURNDATASIZE", 256))

            elif op == opcode_values.RETURNDATACOPY:
                old_state = state.copy()
                mem_start_position = solution(state.stack_pop())
                returndata_start_position = solution(state.stack_pop())

                size = state.stack_pop()
                try:
                    size = solution(size)
                except MultipleSolutionsError:
                    self.add_for_fuzzing(old_state, size,
                                         RETURNDATACOPY_SIZE_FUZZ)
                    return False

                state.memory.write(mem_start_position, size,
                                   claripy.BVS("RETURNDATACOPY", size * 8))

            elif op == opcode_values.SELFDESTRUCT:
                # The beneficiary is read from the stack top without popping.
                state.selfdestruct_to = state.stack[-1]
                return True

            elif op == opcode_values.REVERT:
                return False

            else:
                raise utils.InterpreterError(state, "Unknown opcode %#x" % op)

            state.pc += 1
Exemplo n.º 9
0
    def exec_branch(self, state):  # pylint:disable=invalid-name
        """Execute forward from a state, queuing new states if needed."""
        logger.debug("Constraints: %s", state.solver.constraints)

        def solution(variable):
            """Returns the solution. There must be one or we fail."""
            solutions = state.solver.eval(variable, 2)
            if len(solutions) > 1:
                raise ValueError("Ambiguous solution for %s (%s)" %
                                 (variable, self.code[state.pc]))
            solution = solutions[0]
            return solution if isinstance(solution,
                                          numbers.Number) else solution.value

        state.score += 1
        self.code.pc = state.pc

        while True:
            if state.pc >= len(self.code):
                return True

            op = self.code.next()
            self.coverage[state.pc] += 1

            logger.debug("NEW STEP")
            logger.debug("Memory: %s", state.memory)
            logger.debug("Stack: %s", state.stack)
            logger.debug("PC: %i, %s", state.pc, op)

            assert self.code.pc == state.pc + 1
            assert isinstance(op, numbers.Number)
            assert all(
                hasattr(i, "symbolic") for i in
                state.stack), "The stack musty only contains claripy BV's"

            # Trivial operations first
            if not self.code.is_valid_opcode(state.pc):
                raise utils.CodeError("Trying to execute PUSH data")
            elif op == 254:  # INVALID opcode
                raise utils.CodeError("designed INVALID opcode")
            elif op == opcode_values.JUMPDEST:
                pass
            elif op == opcode_values.ADD:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 + s1)
            elif op == opcode_values.SUB:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 - s1)
            elif op == opcode_values.MUL:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(s0 * s1)
            elif op == opcode_values.DIV:
                # Use claripy.LShR instead of a division whenever possible:
                # the solver copes poorly with divisions but handles shifts
                # well, and shifts are needed to handle the Solidity ABI
                # function selection.
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)  # pylint:disable=invalid-name
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 / s1))
                else:
                    if s1 == 0:
                        state.stack_push(BVV_0)
                    elif s1 == 1:
                        state.stack_push(s0)
                    elif s1 & (s1 - 1) == 0:
                        exp = int(math.log(s1, 2))
                        state.stack_push(s0.LShR(exp))
                    else:
                        state.stack_push(s0 / s1)
            elif op == opcode_values.SDIV:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SDiv(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SDiv(s1))
            elif op == opcode_values.MOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0 % s1))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0 % s1)
            elif op == opcode_values.SMOD:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                try:
                    s1 = solution(s1)
                except ValueError:
                    state.stack_push(claripy.If(s1 == 0, BVV_0, s0.SMod(s1)))
                else:
                    state.stack_push(BVV_0 if s1 == 0 else s0.SMod(s1))
            elif op == opcode_values.ADDMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                try:
                    s2 = solution(s2)
                except ValueError:
                    state.stack_push(claripy.If(s2 == 0, BVV_0,
                                                (s0 + s1) % s2))
                else:
                    state.stack_push(BVV_0 if s2 == 0 else (s0 + s1) % s2)
            elif op == opcode_values.MULMOD:
                s0, s1, s2 = state.stack_pop(), state.stack_pop(
                ), state.stack_pop()
                try:
                    s2 = solution(s2)
                except ValueError:
                    state.stack_push(claripy.If(s2 == 0, BVV_0,
                                                (s0 * s1) % s2))
                else:
                    state.stack_push(BVV_0 if s2 == 0 else (s0 * s1) % s2)
            elif op == opcode_values.EXP:
                base, exponent = solution(state.stack_pop()), state.stack_pop()
                if base == 2:
                    state.stack_push(1 << exponent)
                else:
                    exponent = solution(exponent)
                    state.stack_push(claripy.BVV(base**exponent, 256))
            elif op == opcode_values.LT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.ULT(s0, s1))
            elif op == opcode_values.GT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.UGT(s0, s1))
            elif op == opcode_values.SLT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.SLT(s0, s1))
            elif op == opcode_values.SGT:
                s0, s1 = (
                    not_bool(state.stack_pop()),
                    not_bool(state.stack_pop()),
                )  # pylint:disable=invalid-name
                state.stack_push(claripy.SGT(s0, s1))
            elif op == opcode_values.SIGNEXTEND:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                # s0 is the index of the byte that carries the sign bit
                # (bit s0 * 8 + 7); s1 is the value we want to extend.
                s0 = solution(s0)
                if s0 <= 31:
                    sign_bit = 1 << (s0 * 8 + 7)
                    state.stack_push(
                        claripy.If(
                            s1 & sign_bit == 0,
                            s1 & (sign_bit - 1),
                            s1 | ((1 << 256) - sign_bit),
                        ))
                else:
                    state.stack_push(s1)
            elif op == opcode_values.EQ:
                s0, s1 = state.stack_pop(), state.stack_pop()
                if isinstance(s0, claripy.ast.Bool) and isinstance(
                        s1, claripy.ast.Bool):
                    state.stack_push(s0 == s1)
                else:
                    state.stack_push(not_bool(s0) == not_bool(s1))
            elif op == opcode_values.ISZERO:
                condition = state.stack_pop()
                if isinstance(condition, claripy.ast.Bool):
                    state.stack_push(claripy.Not(condition))
                else:
                    state.stack_push(condition == BVV_0)
            elif op == opcode_values.AND:
                s0, s1 = make_consistent(state.stack_pop(), state.stack_pop())
                if isinstance(s0, claripy.ast.Bool) and isinstance(
                        s1, claripy.ast.Bool):
                    state.stack_push(s0 and s1)
                else:
                    state.stack_push(s0 & s1)
            elif op == opcode_values.OR:
                s0, s1 = make_consistent(state.stack_pop(), state.stack_pop())
                if isinstance(s0, claripy.ast.Bool) and isinstance(
                        s1, claripy.ast.Bool):
                    state.stack_push(s0 or s1)
                else:
                    state.stack_push(s0 | s1)
            elif op == opcode_values.XOR:
                s0, s1 = make_consistent(state.stack_pop(), state.stack_pop())
                state.stack_push(s0 ^ s1)
            elif op == opcode_values.NOT:
                state.stack_push(~state.stack_pop())
            elif op == opcode_values.BYTE:
                s0, s1 = (
                    state.stack_pop(),
                    state.stack_pop(),
                )  # pylint:disable=invalid-name
                state.stack_push(
                    s1.LShR(claripy.If(s0 > 31, 32, 31 - s0) * 8) & 0xFF)

            elif op == opcode_values.PC:
                state.stack_push(bvv(state.pc))
            elif op == opcode_values.GAS:
                state.stack_push(state.env.gas)
            elif op == opcode_values.ADDRESS:
                state.stack_push(state.env.address)
            elif op == opcode_values.BALANCE:
                addr = solution(state.stack_pop())
                if addr != solution(state.env.address):
                    raise utils.InterpreterError(
                        state,
                        "Can only query balance of the current contract for now"
                    )
                state.stack_push(state.env.balance)
            elif op == opcode_values.ORIGIN:
                state.stack_push(state.env.origin)
            elif op == opcode_values.CALLER:
                state.stack_push(state.env.caller)
            elif op == opcode_values.CALLVALUE:
                state.stack_push(state.env.value)
            elif op == opcode_values.BLOCKHASH:
                block_num = state.stack_pop()
                if block_num not in state.env.block_hashes:
                    state.env.block_hashes[block_num] = claripy.BVS(
                        "blockhash[%s]" % block_num, 256)
                state.stack_push(state.env.block_hashes[block_num])
            elif op == opcode_values.TIMESTAMP:
                state.stack_push(state.env.block_timestamp)
            elif op == opcode_values.NUMBER:
                state.stack_push(state.env.block_number)
            elif op == opcode_values.COINBASE:
                state.stack_push(state.env.coinbase)
            elif op == opcode_values.DIFFICULTY:
                state.stack_push(state.env.difficulty)
            elif op == opcode_values.POP:
                state.stack_pop()
            elif op == opcode_values.JUMP:
                addr = solution(state.stack_pop())
                if addr >= len(self.code
                               ) or self.code[addr] != opcode_values.JUMPDEST:
                    raise utils.CodeError("Invalid jump (%i)" % addr)
                state.pc = addr
                self.add_branch(state)
                return False
            elif op == opcode_values.JUMPI:
                addr, condition = solution(
                    state.stack_pop()), state.stack_pop()
                state_false = state.copy()
                if isinstance(condition, claripy.ast.Bool):
                    state.solver.add(condition)
                    state_false.solver.add(claripy.Not(condition))
                else:
                    state.solver.add(condition != 0)
                    state_false.solver.add(condition == 0)
                state_false.pc += 1
                self.add_branch(state_false)
                state.pc = addr
                if (state.pc >= len(self.code)
                        or self.code[state.pc] != opcode_values.JUMPDEST):
                    raise utils.CodeError("Invalid jump (%i)" % (state.pc - 1))
                self.add_branch(state)
                return False
            elif opcode_values.PUSH1 <= op <= opcode_values.PUSH32:
                pushnum = op - opcode_values.PUSH1 + 1
                raw_value = self.code.read(pushnum)
                state.pc += pushnum
                state.stack_push(
                    bvv(int.from_bytes(raw_value, byteorder="big")))
            elif opcode_values.DUP1 <= op <= opcode_values.DUP16:
                depth = op - opcode_values.DUP1 + 1
                state.stack_push(state.stack[-depth])
            elif opcode_values.SWAP1 <= op <= opcode_values.SWAP16:
                depth = op - opcode_values.SWAP1 + 1
                temp = state.stack[-depth - 1]
                state.stack[-depth - 1] = state.stack[-1]
                state.stack[-1] = temp
            elif opcode_values.LOG0 <= op <= opcode_values.LOG4:
                depth = op - opcode_values.LOG0
                mstart, msz = (state.stack_pop(), state.stack_pop())
                topics = [state.stack_pop() for x in range(depth)]
            elif op == opcode_values.SHA3:
                start, length = solution(state.stack_pop()), solution(
                    state.stack_pop())
                memory = state.memory.read(start, length)
                state.stack_push(Sha3(memory))
            elif op == opcode_values.STOP:
                return True
            elif op == opcode_values.RETURN:
                return True

            elif op == opcode_values.CALLDATALOAD:
                indexes = state.stack_pop()
                try:
                    index = solution(indexes)
                except ValueError:  # Multiple solutions, let's fuzz.
                    state.stack_push(indexes)  # restore the stack
                    self.add_for_fuzzing(state, indexes, CALLDATASIZE_FUZZ)
                    return False
                state.solver.add(state.env.calldata_size >= index + 32)
                state.stack_push(state.env.calldata.read(index, 32))
            elif op == opcode_values.CALLDATASIZE:
                state.stack_push(state.env.calldata_size)
            elif op == opcode_values.CALLDATACOPY:
                old_state = state.copy()
                mstart, dstart, size = (
                    state.stack_pop(),
                    state.stack_pop(),
                    state.stack_pop(),
                )
                mstart, dstart = solution(mstart), solution(dstart)
                try:
                    size = solution(size)
                except ValueError:
                    self.add_for_fuzzing(old_state, size, CALLDATASIZE_FUZZ)
                    return False
                state.memory.copy_from(state.env.calldata, mstart, dstart,
                                       size)
                state.solver.add(state.env.calldata_size >= dstart + size)
            elif op == opcode_values.CODESIZE:
                state.stack_push(bvv(len(self.code)))
            elif op == opcode_values.EXTCODESIZE:
                addr = state.stack_pop()
                if (addr == state.env.address).is_true():
                    state.stack_push(bvv(len(self.code)))
                else:
                    # TODO: Improve this; a fresh unconstrained symbol is
                    # clearly not constraining enough.
                    state.stack_push(claripy.BVS("EXTCODESIZE[%s]" % addr,
                                                 256))
            elif op == opcode_values.CODECOPY:
                mem_start, code_start, size = [
                    solution(state.stack_pop()) for _ in range(3)
                ]
                for i in range(size):
                    if code_start + i < len(state.env.code):
                        state.memory.write(
                            mem_start + i,
                            1,
                            claripy.BVV(state.env.code[code_start + i], 8),
                        )
                    else:
                        state.memory.write(mem_start + i, 1, claripy.BVV(0, 8))

            elif op == opcode_values.MLOAD:
                index = solution(state.stack_pop())
                state.stack_push(state.memory.read(index, 32))
            elif op == opcode_values.MSTORE:
                index, value = solution(state.stack_pop()), not_bool(
                    state.stack_pop())
                state.memory.write(index, 32, value)
            elif op == opcode_values.MSTORE8:
                index, value = solution(state.stack_pop()), not_bool(
                    state.stack_pop())
                state.memory.write(index, 1, value[7:0])
            elif op == opcode_values.MSIZE:
                state.stack_push(bvv(state.memory.size()))
            elif op == opcode_values.SLOAD:
                # TODO: This is inaccurate, because the storage can change
                # in a single transaction.
                # See commit d98cab834f8f359f01ef805256d179f5529ebe30.
                key = state.stack_pop()
                if key in state.storage_written:
                    state.stack_push(state.storage_written[key])
                else:
                    if key not in state.storage_read:
                        state.storage_read[key] = claripy.BVS(
                            "storage[%s]" % key, 256)
                    state.stack_push(state.storage_read[key])
            elif op == opcode_values.SSTORE:
                # TODO: This is inaccurate, because the storage can change
                # in a single transaction.
                # See commit d98cab834f8f359f01ef805256d179f5529ebe30.
                key = state.stack_pop()
                value = state.stack_pop()
                state.storage_written[key] = value

            elif op == opcode_values.CALL:
                state.pc += 1

                # First possibility: the call fails
                # (always possible with a call stack big enough)
                state_fail = state.copy()
                state_fail.stack_push(claripy.BoolV(False))
                self.add_branch(state_fail)

                # Second possibility: success.
                state.calls.append(state.stack[-7:])

                # pylint:disable=unused-variable
                gas, to_, value, meminstart, meminsz, memoutstart, memoutsz = (
                    state.stack_pop() for _ in range(7))

                if solution(memoutsz) != 0:
                    raise utils.InterpreterError(state,
                                                 "CALL seems to return data")
                if solution(meminsz) != 0:
                    raise utils.InterpreterError(state,
                                                 "CALL seems to take data")

                state.stack_push(claripy.BoolV(True))
                self.add_branch(state)
                return False

            elif op == opcode_values.SELFDESTRUCT:
                state.selfdestruct_to = state.stack[-1]
                return True

            elif op == opcode_values.REVERT:
                return False
            else:
                raise utils.InterpreterError(state, "Unknown opcode %s" % op)

            state.pc += 1