コード例 #1
0
def run_symbolic(program,
                 path,
                 code=None,
                 state=None,
                 ctx=None,
                 inclusive=False):
    """Symbolically execute ``program`` along the program-counter ``path``.

    Concrete operands are kept as Python ints; anything depending on unknown
    input becomes a z3 expression. Path conditions are collected as z3
    constraints while the interpreter follows ``path``.

    Args:
        program: Mapping from program counter to instruction. Each
            instruction must expose ``name``, ``op``, ``arg`` and ``gas``.
        path: Non-empty list of program counters to visit in order. The last
            entry names the target instruction.
        code: Contract code. Used to build a fresh state when ``state`` is
            not given and to seed ``ctx['CODESIZE-ADDRESS']``.
        state: Existing state to continue from. It is mutated in place. When
            falsy, a new ``SymbolicEVMState`` is created.
        ctx: Dict of concrete environment values, looked up through
            ``ctx_or_symbolic``. A non-empty dict is mutated in place
            (``CODESIZE-ADDRESS`` is written).
        inclusive: When False, return as soon as the last path entry is
            reached, before executing it. When True, return once every path
            entry has been consumed, i.e. after the last one was executed.

    Returns:
        A ``SymbolicResult`` built from the execution id, the final state,
        the collected constraints and the SHA3 constraints. The target
        opcode name is included only when the path end is reached.

    Raises:
        IntractablePath: Execution cannot follow ``path`` (wrong branch
            taken, or the program halts with path entries remaining).
        SymbolicError: A symbolic value is used where only a concrete one is
            supported.
        ExternalData: The code reads return data or external code.
        VMException: A jump targets something that is not a JUMPDEST.
    """
    MAX_CALLDATA_SIZE = 256
    xid = gen_exec_id()
    state = state or SymbolicEVMState(xid=xid, code=code)
    storage = state.storage
    constraints = []
    sha_constraints = dict()
    ctx = ctx or dict()
    epoch = datetime.datetime(1970, 1, 1)
    min_timestamp = int((datetime.datetime.now() - epoch).total_seconds())
    # make sure we can exploit it in the foreseeable future
    max_timestamp = int(
        (datetime.datetime(2020, 1, 1) - epoch).total_seconds())
    if max_timestamp < min_timestamp:
        # The fixed upper bound now lies in the past, which would make the
        # TIMESTAMP constraints below unsatisfiable. Fall back to a window of
        # one year starting now.
        max_timestamp = min_timestamp + 365 * 24 * 60 * 60
    # `code` may be omitted when an existing state is passed in; the state
    # carries the code in that case.
    ctx['CODESIZE-ADDRESS'] = len(code) if code is not None else len(
        state.code)
    calldata = z3.Array('CALLDATA_%d' % xid, z3.BitVecSort(256),
                        z3.BitVecSort(8))
    calldatasize = z3.BitVec('CALLDATASIZE_%d' % xid, 256)
    instruction_count = 0
    state.balance += ctx_or_symbolic('CALLVALUE', ctx, xid)

    target_op = program[path[-1]].name

    while state.pc in program:
        state.trace.append(state.pc)
        instruction_count += 1

        # have we reached the end of our path?
        if ((inclusive and len(path) == 0)
                or (not inclusive and path == [state.pc])):
            state.success = True
            return SymbolicResult(xid, state, constraints, sha_constraints,
                                  target_op)

        # if not, have we reached another step of our path?
        elif state.pc == path[0]:
            path = path[1:]

        ins = program[state.pc]
        opcode = ins.op
        op = ins.name
        stk = state.stack
        mem = state.memory
        state.gas -= ins.gas
        # Valid operations
        # Pushes first because they are very frequent
        if 0x60 <= opcode <= 0x7f:
            stk.append(int.from_bytes(ins.arg, byteorder='big'))
            state.pc += opcode - 0x5f  # Move 1 byte forward for 0x60, up to 32 bytes for 0x7f
        # Arithmetic
        # NOTE(review): concrete results of ADD/SUB/MUL are not reduced
        # modulo 2**256 here - confirm whether callers rely on that.
        elif opcode < 0x10:
            if op == 'STOP':
                if path:
                    raise IntractablePath(state.trace, path)
                state.success = True
                return SymbolicResult(xid, state, constraints, sha_constraints)
            elif op == 'ADD':
                stk.append(stk.pop() + stk.pop())
            elif op == 'SUB':
                stk.append(stk.pop() - stk.pop())
            elif op == 'MUL':
                stk.append(stk.pop() * stk.pop())
            elif op == 'DIV':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s1):
                    if s1 == 0:
                        stk.append(0)
                    elif concrete(s0):
                        # Integer division; `/` would yield a float.
                        stk.append(s0 // s1)
                    else:
                        stk.append(z3.UDiv(s0, s1))
                else:
                    stk.append(
                        z3.If(s1 == 0, z3.BitVecVal(0, 256), z3.UDiv(s0, s1)))
            elif op == 'MOD':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s1):
                    if s1 == 0:
                        stk.append(0)
                    elif concrete(s0):
                        stk.append(s0 % s1)
                    else:
                        # `%` on a z3 bit-vector is the signed modulo; MOD is
                        # unsigned.
                        stk.append(z3.URem(s0, s1))
                else:
                    stk.append(
                        z3.If(s1 == 0, z3.BitVecVal(0, 256), z3.URem(s0, s1)))
            elif op == 'SDIV':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    s0, s1 = teether.util.utils.to_signed(
                        s0), teether.util.utils.to_signed(s1)
                    stk.append(0 if s1 == 0 else abs(s0) // abs(s1) *
                               (-1 if s0 * s1 < 0 else 1))
                elif concrete(s1):
                    # `/` on a z3 bit-vector is the signed division.
                    stk.append(0 if s1 == 0 else s0 / s1)
                else:
                    stk.append(z3.If(s1 == 0, z3.BitVecVal(0, 256), s0 / s1))
            elif op == 'SMOD':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    s0, s1 = teether.util.utils.to_signed(
                        s0), teether.util.utils.to_signed(s1)
                    stk.append(0 if s1 == 0 else abs(s0) % abs(s1) *
                               (-1 if s0 < 0 else 1))
                elif concrete(s1):
                    stk.append(0 if s1 == 0 else z3.SRem(s0, s1))
                else:
                    stk.append(
                        z3.If(s1 == 0, z3.BitVecVal(0, 256), z3.SRem(s0, s1)))
            elif op == 'ADDMOD':
                s0, s1, s2 = stk.pop(), stk.pop(), stk.pop()
                if concrete(s2):
                    stk.append((s0 + s1) % s2 if s2 else 0)
                else:
                    stk.append(
                        z3.If(s2 == 0, z3.BitVecVal(0, 256),
                              z3.URem((s0 + s1), s2)))
            elif op == 'MULMOD':
                s0, s1, s2 = stk.pop(), stk.pop(), stk.pop()
                if concrete(s2):
                    stk.append((s0 * s1) % s2 if s2 else 0)
                else:
                    stk.append(
                        z3.If(s2 == 0, z3.BitVecVal(0, 256),
                              z3.URem((s0 * s1), s2)))
            elif op == 'EXP':
                base, exponent = stk.pop(), stk.pop()
                if concrete(base) and concrete(exponent):
                    stk.append(pow(base, exponent, teether.util.utils.TT256))
                else:
                    if concrete(base) and teether.util.utils.is_pow2(base):
                        # (2**k)**e == 1 << (k * e), which z3 can express.
                        l2 = teether.util.utils.log2(base)
                        stk.append(1 << (l2 * exponent))
                    else:
                        raise SymbolicError(
                            'exponentiation with symbolic exponent currently not supported :-/'
                        )
            elif op == 'SIGNEXTEND':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    if s0 <= 31:
                        testbit = s0 * 8 + 7
                        if s1 & (1 << testbit):
                            stk.append(s1 | (teether.util.utils.TT256 -
                                             (1 << testbit)))
                        else:
                            stk.append(s1 & ((1 << testbit) - 1))
                    else:
                        stk.append(s1)
                elif concrete(s0):
                    if s0 <= 31:
                        oldwidth = (s0 + 1) * 8
                        stk.append(z3.SignExt(256 - oldwidth, s1))
                    else:
                        stk.append(s1)
                else:
                    raise SymbolicError(
                        'symbolic bitwidth for signextension is currently not supported'
                    )
        # Comparisons
        elif opcode < 0x20:
            if op == 'LT':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    stk.append(1 if s0 < s1 else 0)
                else:
                    stk.append(
                        z3.If(z3.ULT(s0, s1), z3.BitVecVal(1, 256),
                              z3.BitVecVal(0, 256)))
            elif op == 'GT':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    stk.append(1 if s0 > s1 else 0)
                else:
                    stk.append(
                        z3.If(z3.UGT(s0, s1), z3.BitVecVal(1, 256),
                              z3.BitVecVal(0, 256)))
            elif op == 'SLT':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    s0, s1 = teether.util.utils.to_signed(
                        s0), teether.util.utils.to_signed(s1)
                    stk.append(1 if s0 < s1 else 0)
                else:
                    stk.append(
                        z3.If(s0 < s1, z3.BitVecVal(1, 256),
                              z3.BitVecVal(0, 256)))
            elif op == 'SGT':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    s0, s1 = teether.util.utils.to_signed(
                        s0), teether.util.utils.to_signed(s1)
                    stk.append(1 if s0 > s1 else 0)
                else:
                    stk.append(
                        z3.If(s0 > s1, z3.BitVecVal(1, 256),
                              z3.BitVecVal(0, 256)))
            elif op == 'EQ':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0) and concrete(s1):
                    stk.append(1 if s0 == s1 else 0)
                else:
                    stk.append(
                        z3.If(s0 == s1, z3.BitVecVal(1, 256),
                              z3.BitVecVal(0, 256)))
            elif op == 'ISZERO':
                s0 = stk.pop()
                if concrete(s0):
                    stk.append(1 if s0 == 0 else 0)
                else:
                    stk.append(
                        z3.If(s0 == 0, z3.BitVecVal(1, 256),
                              z3.BitVecVal(0, 256)))
            elif op == 'AND':
                stk.append(stk.pop() & stk.pop())
            elif op == 'OR':
                stk.append(stk.pop() | stk.pop())
            elif op == 'XOR':
                stk.append(stk.pop() ^ stk.pop())
            elif op == 'NOT':
                stk.append(~stk.pop())
            elif op == 'BYTE':
                s0, s1 = stk.pop(), stk.pop()
                if concrete(s0):
                    if s0 >= 32:
                        stk.append(0)
                    else:
                        if concrete(s1):
                            stk.append((s1 // 256**(31 - s0)) % 256)
                        else:
                            v = z3.simplify(
                                z3.Extract((31 - s0) * 8 + 7, (31 - s0) * 8,
                                           s1))
                            if z3.is_bv_value(v):
                                stk.append(v.as_long())
                            else:
                                # The extract is 8 bits wide, so 248 zero bits
                                # are needed to get back to a 256-bit word.
                                stk.append(z3.ZeroExt(256 - 8, v))
                else:
                    raise SymbolicError('symbolic byte-index not supported')
        # SHA3 and environment info
        elif opcode < 0x40:
            if op == 'SHA3':
                s0, s1 = stk.pop(), stk.pop()
                mem.extend(s0, s1)
                mm = mem.read(s0, s1)
                if not isinstance(mm, SymRead) and all(
                        concrete(m) for m in mm):
                    data = teether.util.utils.bytearray_to_bytestr(mm)
                    stk.append(
                        teether.util.utils.big_endian_to_int(
                            teether.util.utils.sha3(data)))
                else:
                    if not isinstance(mm, SymRead):
                        sha_data = z3.simplify(
                            z3.Concat([
                                m if z3.is_expr(m) else z3.BitVecVal(m, 8)
                                for m in mm
                            ]))
                        # Reuse the hash variable of an earlier SHA3 over
                        # provably identical data.
                        for k, v in sha_constraints.items():
                            if isinstance(v, SymRead):
                                continue
                            if v.size() == sha_data.size() and is_true(
                                    v == sha_data):
                                sha = k
                                break
                        else:
                            sha = z3.BitVec(
                                'SHA3_%x_%d' % (instruction_count, xid), 256)
                            sha_constraints[sha] = sha_data
                    else:
                        sha_data = mm
                        sha = z3.BitVec(
                            'SHA3_%x_%d' % (instruction_count, xid), 256)
                        sha_constraints[sha] = sha_data
                    stk.append(sha)
            elif op == 'ADDRESS':
                stk.append(ctx_or_symbolic('ADDRESS', ctx, xid))
            elif op == 'BALANCE':
                s0 = stk.pop()
                if concrete(s0):
                    stk.append(ctx_or_symbolic('BALANCE-%x' % s0, ctx, xid))
                elif is_true(
                        addr(s0) == addr(ctx_or_symbolic('ADDRESS', ctx,
                                                         xid))):
                    stk.append(state.balance)
                elif is_true(
                        addr(s0) == addr(ctx_or_symbolic('CALLER', ctx, xid))):
                    stk.append(ctx_or_symbolic('BALANCE-CALLER', ctx, xid))
                else:
                    raise SymbolicError('balance of symbolic address (%s)' %
                                        str(z3.simplify(s0)))
            elif op == 'ORIGIN':
                stk.append(ctx_or_symbolic('ORIGIN', ctx, xid))
            elif op == 'CALLER':
                stk.append(ctx_or_symbolic('CALLER', ctx, xid))
            elif op == 'CALLVALUE':
                stk.append(ctx_or_symbolic('CALLVALUE', ctx, xid))
            elif op == 'CALLDATALOAD':
                s0 = stk.pop()
                constraints.append(z3.UGE(calldatasize, s0 + 32))
                if not concrete(s0):
                    constraints.append(z3.ULT(s0, MAX_CALLDATA_SIZE))
                stk.append(z3.Concat([calldata[s0 + i] for i in range(32)]))
            elif op == 'CALLDATASIZE':
                stk.append(calldatasize)
            elif op == 'CALLDATACOPY':
                mstart, dstart, size = stk.pop(), stk.pop(), stk.pop()
                constraints.append(z3.UGE(calldatasize, dstart + size))
                if not concrete(dstart):
                    constraints.append(z3.ULT(dstart, MAX_CALLDATA_SIZE))
                if concrete(size):
                    for i in range(size):
                        mem[mstart + i] = calldata[dstart + i]
                else:
                    constraints.append(z3.ULT(size, MAX_CALLDATA_SIZE))
                    for i in range(MAX_CALLDATA_SIZE):
                        # Byte i is copied only when i < size; otherwise the
                        # memory cell keeps its previous content.
                        mem[mstart + i] = z3.If(z3.ULE(size, i),
                                                mem[mstart + i],
                                                calldata[dstart + i])
            elif op == 'CODESIZE':
                stk.append(len(state.code))
            elif op == 'CODECOPY':
                mstart, dstart, size = stk.pop(), stk.pop(), stk.pop()
                if concrete(mstart) and concrete(dstart) and concrete(size):
                    mem.extend(mstart, size)
                    for i in range(size):
                        if dstart + i < len(state.code):
                            mem[mstart + i] = state.code[dstart + i]
                        else:
                            mem[mstart + i] = 0
                else:
                    raise SymbolicError('Symbolic code index @ %s' % ins)
            elif op == 'RETURNDATACOPY':
                raise ExternalData('RETURNDATACOPY')
            elif op == 'RETURNDATASIZE':
                raise ExternalData('RETURNDATASIZE')
            elif op == 'GASPRICE':
                stk.append(ctx_or_symbolic('GASPRICE', ctx, xid))
            elif op == 'EXTCODESIZE':
                s0 = stk.pop()
                # NOTE(review): BALANCE compares addr(s0) against addr(...),
                # here s0 is compared unwrapped - confirm this is intended.
                if concrete(s0):
                    stk.append(ctx_or_symbolic('CODESIZE-%x' % s0, ctx, xid))
                elif is_true(s0 == addr(ctx_or_symbolic('ADDRESS', ctx, xid))):
                    stk.append(ctx_or_symbolic('CODESIZE-ADDRESS', ctx, xid))
                elif is_true(s0 == addr(ctx_or_symbolic('CALLER', ctx, xid))):
                    stk.append(ctx_or_symbolic('CODESIZE-CALLER', ctx, xid))
                else:
                    raise SymbolicError('codesize of symblic address')
            elif op == 'EXTCODECOPY':
                raise ExternalData('EXTCODECOPY')
        # Block info
        elif opcode < 0x50:
            if op == 'BLOCKHASH':
                s0 = stk.pop()
                if not concrete(s0):
                    raise SymbolicError('symbolic blockhash index')
                stk.append(ctx_or_symbolic('BLOCKHASH[%d]' % s0, ctx, xid))
            elif op == 'COINBASE':
                stk.append(ctx_or_symbolic('COINBASE', ctx, xid))
            elif op == 'TIMESTAMP':
                ts = ctx_or_symbolic('TIMESTAMP', ctx, xid)
                if not concrete(ts):
                    constraints.append(z3.UGE(ts, min_timestamp))
                    constraints.append(z3.ULE(ts, max_timestamp))
                stk.append(ts)
            elif op == 'NUMBER':
                stk.append(ctx_or_symbolic('NUMBER', ctx, xid))
            elif op == 'DIFFICULTY':
                stk.append(ctx_or_symbolic('DIFFICULTY', ctx, xid))
            elif op == 'GASLIMIT':
                stk.append(ctx_or_symbolic('GASLIMIT', ctx, xid))
        # VM state manipulations
        elif opcode < 0x60:
            if op == 'POP':
                stk.pop()
            elif op == 'MLOAD':
                s0 = stk.pop()
                mem.extend(s0, 32)
                mm = [mem[s0 + i] for i in range(32)]
                if all(concrete(m) for m in mm):
                    stk.append(
                        teether.util.utils.bytes_to_int(mem.read(s0, 32)))
                else:
                    v = z3.simplify(
                        z3.Concat([
                            m if not concrete(m) else z3.BitVecVal(m, 8)
                            for m in mm
                        ]))
                    if z3.is_bv_value(v):
                        stk.append(v.as_long())
                    else:
                        stk.append(v)
            elif op == 'MSTORE':
                s0, s1 = stk.pop(), stk.pop()
                mem.extend(s0, 32)
                if concrete(s1):
                    mem.write(s0, 32, teether.util.utils.encode_int32(s1))
                else:
                    # Store the word byte by byte, most significant first.
                    for i in range(32):
                        m = z3.simplify(
                            z3.Extract((31 - i) * 8 + 7, (31 - i) * 8, s1))
                        if z3.is_bv_value(m):
                            mem[s0 + i] = m.as_long()
                        else:
                            mem[s0 + i] = m
            elif op == 'MSTORE8':
                s0, s1 = stk.pop(), stk.pop()
                mem.extend(s0, 1)
                mem[s0] = s1 % 256
            elif op == 'SLOAD':
                s0 = stk.pop()

                v = z3.simplify(storage[s0])
                if z3.is_bv_value(v):
                    stk.append(v.as_long())
                else:
                    stk.append(v)
            elif op == 'SSTORE':
                s0, s1 = stk.pop(), stk.pop()
                storage[s0] = s1
            elif op == 'JUMP':
                s0 = stk.pop()
                if not concrete(s0):
                    raise SymbolicError('Symbolic jump target')
                state.pc = s0
                if state.pc >= len(state.code) or not program[
                        state.pc].name == 'JUMPDEST':
                    raise VMException('BAD JUMPDEST')
                continue
            elif op == 'JUMPI':
                s0, s1 = stk.pop(), stk.pop()
                next_target = path[0]
                if concrete(s1):
                    if s1:
                        if not concrete(s0):
                            raise SymbolicError('Symbolic jump target')
                        # The branch is taken, but the path wants the
                        # fall-through.
                        if s0 != next_target and state.pc + 1 == next_target:
                            raise IntractablePath(state.trace, path)
                        state.pc = s0
                        if state.pc >= len(state.code) or not program[
                                state.pc].name == 'JUMPDEST':
                            raise VMException('BAD JUMPDEST')
                        continue
                    else:
                        if concrete(s0):
                            # The branch is not taken, but the path wants the
                            # jump target.
                            if state.pc + 1 != next_target and s0 == next_target:
                                raise IntractablePath(state.trace, path)
                else:
                    # Symbolic condition: constrain it so execution follows
                    # the requested path.
                    if state.pc + 1 == next_target:
                        if not (concrete(s0) and s0 == next_target):
                            constraints.append(s1 == 0)
                    elif concrete(s0) and s0 == next_target:
                        if state.pc + 1 != next_target:
                            constraints.append(s1 != 0)
                        state.pc = s0
                        if state.pc >= len(state.code) or not program[
                                state.pc].name == 'JUMPDEST':
                            raise VMException('BAD JUMPDEST')
                        continue
                    elif not concrete(s0):
                        raise SymbolicError('Symbolic jump target')
                    else:
                        raise IntractablePath(state.trace, path)

            elif op == 'PC':
                stk.append(state.pc)
            elif op == 'MSIZE':
                stk.append(len(mem))
            elif op == 'GAS':
                stk.append(z3.BitVec('GAS_%x' % instruction_count, 256))
        # DUPn (eg. DUP1: a b c -> a b c c, DUP3: a b c -> a b c a)
        elif op[:3] == 'DUP':
            stk.append(
                stk[0x7f - opcode]
            )  # 0x7f - opcode is a negative number, -1 for 0x80 ... -16 for 0x8f
        # SWAPn (eg. SWAP1: a b c d -> a b d c, SWAP3: a b c d -> d b c a)
        elif op[:4] == 'SWAP':
            # 0x8e - opcode is a negative number, -2 for 0x90 ... -17 for 0x9f
            stk[-1], stk[0x8e - opcode] = stk[0x8e - opcode], stk[-1]
        # Logs (aka "events")
        elif op[:3] == 'LOG':
            # 0xa0 ... 0xa4, 32/64/96/128/160 + len(data) gas
            # a. Opcodes LOG0...LOG4 are added, takes 2-6 stack arguments
            #        MEMSTART MEMSZ (TOPIC1) (TOPIC2) (TOPIC3) (TOPIC4)
            # b. Logs are kept track of during tx execution exactly the same
            #    way as selfdestructs (except as an ordered list, not a set).
            #    Each log is in the form [address, [topic1, ... ], data] where:
            #    * address is what the ADDRESS opcode would output
            #    * data is mem[MEMSTART: MEMSTART + MEMSZ]
            #    * topics are as provided by the opcode
            # c. The ordered list of logs in the transaction are expressed as
            #    [log0, log1, ..., logN].
            depth = int(op[3:])
            mstart, msz = stk.pop(), stk.pop()
            # The topics are only removed from the stack; their values are
            # not used.
            for _ in range(depth):
                stk.pop()
            mem.extend(mstart, msz)
            # Ignore external effects...
        # Create a new contract
        elif op == 'CREATE':
            s0, s1, s2 = stk.pop(), stk.pop(), stk.pop()
            constraints.append(z3.UGE(state.balance, s0))
            state.balance -= s0
            stk.append(
                addr(
                    z3.BitVec('EXT_CREATE_%d_%d' % (instruction_count, xid),
                              256)))
        # Calls
        elif op in ('CALL', 'CALLCODE', 'DELEGATECALL', 'STATICCALL'):
            if op in ('CALL', 'CALLCODE'):
                s0, s1, s2, s3, s4, s5, s6 = stk.pop(), stk.pop(), stk.pop(
                ), stk.pop(), stk.pop(), stk.pop(), stk.pop()
                if op == 'CALL':
                    constraints.append(z3.UGE(state.balance, s2))
                    state.balance -= s2
            elif op == 'DELEGATECALL':
                s0, s1, s3, s4, s5, s6 = stk.pop(), stk.pop(), stk.pop(
                ), stk.pop(), stk.pop(), stk.pop()
                s2 = ctx_or_symbolic('CALLVALUE', ctx, xid)
            elif op == 'STATICCALL':
                s0, s1, s3, s4, s5, s6 = stk.pop(), stk.pop(), stk.pop(
                ), stk.pop(), stk.pop(), stk.pop()
                s2 = 0

            ostart = s5 if concrete(s5) else z3.simplify(s5)
            olen = s6 if concrete(s6) else z3.simplify(s6)

            if concrete(s1) and s1 <= 8:
                if s1 == 4:
                    logging.info("Calling precompiled identity contract")
                    istart = s3 if concrete(s3) else z3.simplify(s3)
                    ilen = s4 if concrete(s4) else z3.simplify(s4)
                    mem.copy(istart, ilen, ostart, olen)
                    stk.append(1)
                else:
                    raise SymbolicError(
                        "Precompiled contract %d not implemented" % s1)
            else:
                # The output area is filled with fresh symbolic bytes and the
                # call result is a fresh symbolic word.
                for i in range(olen):
                    mem[ostart + i] = z3.BitVec(
                        'EXT_%d_%d_%d' % (instruction_count, i, xid), 8)
                logging.info("Calling contract %s (%d_%d)" %
                             (s1, instruction_count, xid))
                stk.append(
                    z3.BitVec('CALLRESULT_%d_%d' % (instruction_count, xid),
                              256))

        elif op == 'RETURN':
            s0, s1 = stk.pop(), stk.pop()
            if concrete(s0) and concrete(s1):
                mem.extend(s0, s1)
            state.success = True
            if path:
                raise IntractablePath(state.trace, path)
            return SymbolicResult(xid, state, constraints, sha_constraints)
        # Revert opcode (Metropolis)
        elif op == 'REVERT':
            s0, s1 = stk.pop(), stk.pop()
            if not concrete(s0) or not concrete(s1):
                raise SymbolicError('symbolic memory index')
            mem.extend(s0, s1)
            if path:
                raise IntractablePath(state.trace, path)
            return SymbolicResult(xid, state, constraints, sha_constraints)
        # SELFDESTRUCT opcode (formerly called SUICIDE)
        elif op == 'SELFDESTRUCT':
            # The single operand is popped but not used.
            stk.pop()
            state.success = True
            if path:
                raise IntractablePath(state.trace, path)
            return SymbolicResult(xid, state, constraints, sha_constraints)

        state.pc += 1

    # Ran off the end of the program.
    if path:
        raise IntractablePath(state.trace, path)
    state.success = True
    return SymbolicResult(xid, state, constraints, sha_constraints)
コード例 #2
0
class Z3Context(Context):
    """Context that evaluates the abstract operations with z3py.

    Variables are created lazily from ``Sort`` objects and cached in
    ``self.storage`` under the sort's name. Assertions go to one
    ``z3.Solver`` owned by the instance.

    NOTE(review): ``self.storage`` is presumably set up by
    ``Context.__init__`` - confirm against the base class.
    """

    def __init__(self, *args, **kw):
        Context.__init__(self, *args, **kw)
        self.solver = z3.Solver()

    def __getitem__(self, key):
        """Return the z3 value stored for ``key``.

        A ``Sort`` key is looked up by its name and created on first use.
        Any other key must already be present in ``self.storage``.

        Raises:
            ValueError: ``key`` is not a ``Sort`` and is not stored.
        """
        if not isinstance(key, Sort) and key in self.storage:
            return self.storage[key]
        elif isinstance(key, Sort):
            if key.name in self.storage:
                return self.storage[key.name]
            val = self.new_from_sort(key)
            self.storage[key.name] = val
            return val
        else:
            raise ValueError("%s not found! %s. %s." %(key, type(key), self.storage))

    def new_from_sort(self, key):
        """Create a fresh z3 constant matching the sort object ``key``.

        Raises:
            TypeError: ``key`` is not a Bool, Int, String or BitVec sort.
        """
        if isinstance(key, Bool):
            return z3.Bool(key.name)
        elif isinstance(key, Int):
            return z3.Int(key.name)
        elif isinstance(key, String):
            return z3.String(key.name)
        elif isinstance(key, BitVec):
            return z3.BitVec(key.name, key.size)
        raise TypeError("%s not supported!" %type(key))

    def s_assert(self, expr):
        """Assert ``expr`` (evaluated in this context) in the solver."""
        self.solver.assert_exprs(expr.cval(self))

    def s_check(self):
        """Run the solver and return its check result."""
        return self.solver.check()

    def s_model(self):
        """Return the current model as a dict of Python values.

        An empty dict is returned when z3 cannot provide a model.
        """
        try:
            m = self.solver.model()
            return self.process_model(m)
        except z3.Z3Exception:
            return {}

    def s_push(self):
        """Push a backtracking point on the solver."""
        self.solver.push()

    def s_pop(self):
        """Pop the most recent backtracking point from the solver."""
        self.solver.pop()

    def s_reset(self):
        """Remove all assertions from the solver."""
        self.solver.reset()

    def solve(self, AST):
        """Reset the solver and process the nodes of ``AST`` in order.

        ``Sort`` and ``Let`` nodes are asserted. ``Command`` nodes named
        ``push``, ``pop``, ``check-sat`` and ``get-model`` are executed;
        other command names are ignored.

        Returns:
            The results of every ``check-sat`` and ``get-model`` command, in
            the order they were encountered.

        Raises:
            ValueError: A node is not a ``Sort``, ``Let`` or ``Command``.
        """
        outputs = []
        self.s_reset()
        for node in AST:
            if isinstance(node, Sort):
                self.s_assert(node)
            elif isinstance(node, Let):
                self.s_assert(node.term)
            elif isinstance(node, Command):
                if node.cname == "push":
                    self.s_push()
                elif node.cname == "pop":
                    self.s_pop()
                elif node.cname == "check-sat":
                    logger.info("\n-------")
                    outputs.append(self.s_check())
                    logger.info("Check: %s" % outputs[-1])
                elif node.cname == "get-model":
                    outputs.append(self.s_model())
                    logger.info("Model: %s" % outputs[-1])
            else:
                raise ValueError("Command %s not supported!" %node)

        return outputs

    def process_model(self, z3_model):
        """Convert a z3 model into ``{declaration name: Python value}``."""
        m = {}
        for v in z3_model:
            m[v.name()] = self.get_py_value(z3_model.get_interp(v))
        return m

    def get_py_value(self, assignment):
        """Convert a z3 value to the matching Python value.

        Anything that is not a z3 AST is returned unchanged.

        Raises:
            ValueError: The z3 value is not an int, bit-vector, bool or
                string value.
        """
        if z3.is_ast(assignment):
            if z3.is_int_value(assignment):
                return assignment.as_long()
            # Bit-vector variables can be created by new_from_sort, so their
            # model values have to be convertible as well.
            if z3.is_bv_value(assignment):
                return assignment.as_long()
            if z3.is_bool(assignment):
                return z3.is_true(assignment)
            if z3.is_string_value(assignment):
                try:
                    val = assignment.as_string()[1:-1]  # remove quotes
                    val = val.replace("\\x00", "")
                    return str(val)
                # Z3 throws encoding errors. It can't decode its own solution..
                # TODO find a better fix.
                except UnicodeDecodeError:
                    val = assignment.as_ast()
                    return repr(val)
            raise ValueError("Unsupported Z3 type! %s" % type(assignment))

        return assignment

    # --- Literal constructors -------------------------------------------
    BoolVal = lambda self, x : z3.BoolVal(x)
    StringVal = lambda self, x : z3.StringVal(x)
    IntVal = lambda self, x : z3.IntVal(x)
    BitVecVal = lambda self, val, size : z3.BitVecVal(val, size)

    # --- Boolean connectives --------------------------------------------
    And = lambda self, *x : z3.And(x)
    Or = lambda self, *x : z3.Or(x)
    Xor = lambda self, *x : reduce(xor, x)
    Implies = lambda self, x, y : z3.Implies(x, y)
    # Accepts any number of operands, as z3.Distinct does.
    Distinct = lambda self, *x : z3.Distinct(*x)

    def Eq(self, x, y):
        """Return ``x == y``, wrapping plain Python strings for z3 first."""
        # x = z3.String("x")
        # x == "test" #throws an error. This is a workaround for now.
        x = z3.StringVal(x) if isinstance(x,str) else x
        y = z3.StringVal(y) if isinstance(y,str) else y
        return eq(x,y)

    Not = lambda self, x : z3.Not(x)

    If = lambda self, *x : z3.If(*x)

    # --- Arithmetic -------------------------------------------------------
    # Inside the lambda bodies the bare names (add, sub, lt, ...) resolve to
    # the module-level functions, not to these class attributes.
    add = lambda self, *x : reduce(add, x)
    # A single operand means unary minus.
    sub = lambda self, *x : reduce(sub, x) if len(x) > 1 else -x[0]
    mul = lambda self, *x : reduce(mul, x)

    lt = lambda self, *x : reduce(lt, x)
    le = lambda self, *x : reduce(le, x)
    gt = lambda self, *x : reduce(gt, x)
    ge = lambda self, *x : reduce(ge, x)

    # --- Strings ------------------------------------------------------------
    concat = lambda self, *x : reduce(add, x)
    length = lambda self, x : z3.Length(x)
    contains = lambda self, x, y : z3.Contains(x, y)
    indexof = lambda self, x, y, z=0 : z3.IndexOf(x, y, z)
    extract = lambda self, x, y, z : z3.Extract(x, y, z)

    # --- Bit-vectors ------------------------------------------------------
    bvadd = add
    bvsub = sub
    bvmul = mul
    bvxor = Xor
    bvneg = lambda self, x : neg(x)
    bvnot = lambda self, x : inv(x)
    bvconcat = lambda self, *x : z3.Concat(*x)
    bvlshr = lambda self, x, y : z3.LShR(x, y)
    # z3py has no LShL function; left shift is the `<<` operator.
    bvlshl = lambda self, x, y : x << y
    bvuge = lambda self, x, y : z3.UGE(x, y)
    bvurem = lambda self, x, y : z3.URem(x, y)

    # --- Floating point (placeholders, all return None) ------------------
    # TODO Need to define all these with stuff in computation folder
    FPAbs = lambda self, *x : None
    FPNeg = lambda self, *x : None
    FPAdd = lambda self, *x : None
    FPSub = lambda self, *x : None
    FPMul = lambda self, *x : None
    FPDiv = lambda self, *x : None
    FPFMA = lambda self, *x : None
    FPRem = lambda self, *x : None
    FPSqrt = lambda self, *x : None
    FPRoundToIntegral = lambda self, *x : None
    FPMin = lambda self, *x : None
    FPMax = lambda self, *x : None
    FPLEQ = lambda self, *x : None
    FPLT = lambda self, *x : None
    FPGEQ = lambda self, *x : None
    FPGT = lambda self, *x : None
    FPEQ = lambda self, *x : None
    FPIsNormal = lambda self, *x : None
    FPIsSubNormal = lambda self, *x : None
    FPIsZero = lambda self, *x : None
    FPIsInfinite = lambda self, *x : None
    FPIsNan = lambda self, *x : None
    FPIsNegative = lambda self, *x : None
    FPIsPositive = lambda self, *x : None