예제 #1
0
def sys_read(s, cc):
    """Emulate the ``read`` system call against the emulated file table.

    Arguments come from the calling convention ``cc``: ``fd`` (index into
    ``s.files``), ``buf`` (destination pointer) and ``size`` (byte count).

    * For the standard streams (path ``stdin``/``stdout``/``stderr``) every
      requested byte is a fresh symbol named ``file_<fd>_<offset>``.
    * For any other path the host file is read from the current offset; a
      ``#`` byte in the host file becomes a symbol, any other byte a concrete
      constant. Reading stops early at end of file.

    Each produced byte is written to ``buf`` and recorded in
    ``file['bytes'][offset]``. The file offset advances by the number of
    bytes produced, and that count is returned (an unknown ``fd`` returns 0).

    Raises:
        ValueError: if ``fd`` is symbolic.
        NotImplementedError: if ``size`` is symbolic.
    """
    f = cc(s)

    fd = f.params[0]
    buf = f.params[1]
    size = f.params[2]

    s.log.syscall(f, 'sys_read(fd={}, ptr={}, size={})', fd, buf, size)

    if fd.symbolic:
        raise ValueError('wtf')

    try:
        entry = s.files[fd.value]
    except (IndexError, KeyError):
        # unknown descriptor: report that nothing was read
        return f.ret(value=0)

    # checked before any host file is opened, so nothing can leak
    if size.symbolic:
        raise NotImplementedError()

    start = offset = entry['offset']
    output = OutputBuffer(s, buf)

    if entry['path'] in ('stdin', 'stdout', 'stderr'):
        for _ in xrange(size.value):
            b = bv.Symbol(8, 'file_{}_{:x}'.format(fd.value, offset))
            output.append(b)
            entry['bytes'][offset] = b
            offset += 1
    else:
        with open(entry['path'], 'rb') as host_file:
            host_file.seek(offset, 0)
            for _ in xrange(size.value):
                byte = host_file.read(1)
                if len(byte) != 1:
                    break  # end of file: short read
                if byte == b'#':
                    # '#' in the host file marks a byte to treat as symbolic
                    b = bv.Symbol(8, 'file_{}_{:x}'.format(fd.value, offset))
                else:
                    b = bv.Constant(8, ord(byte))
                output.append(b)
                entry['bytes'][offset] = b
                offset += 1

    entry['offset'] = offset

    # number of bytes actually produced (short at EOF), at the width of the
    # ``size`` argument
    return f.ret(value=bv.Constant(size.size, offset - start))
예제 #2
0
def fread(s, cc):
    """Model C ``fread(ptr, size, count, stream)`` over the emulated files.

    ``stream`` is used as an index into ``s.files``. ``size * count`` bytes
    are produced at ``ptr``:

    * for ``stdin``/``stdout``/``stderr`` paths, fresh symbols named
      ``file_<stream>_<offset>``;
    * for other paths, bytes read from the host file at the current offset
      (``#`` becomes a symbol, any other byte a constant), stopping at EOF.

    The file offset advances by the bytes produced. As with C ``fread``, the
    return value is the number of complete items read (0 when ``size`` is 0
    or the stream is unknown).

    Raises:
        ValueError: if ``stream`` is symbolic.
        NotImplementedError: if ``size`` or ``count`` is symbolic.
    """
    f = cc(s)

    buf = f.params[0]
    size = f.params[1]
    count = f.params[2]
    stream = f.params[3]

    s.log.function_call(f, 'fread(ptr={}, size={}, count={}, stream={})', buf, size, count, stream)

    if stream.symbolic:
        raise ValueError('wtf')

    try:
        entry = s.files[stream.value]
    except (IndexError, KeyError):
        # unknown stream: nothing read
        return f.ret(value=0)

    # checked before any host file is opened, so nothing can leak
    if size.symbolic or count.symbolic:
        raise NotImplementedError()

    start = offset = entry['offset']
    output = OutputBuffer(s, buf)
    total = size.value * count.value

    if entry['path'] in ('stdin', 'stdout', 'stderr'):
        for _ in xrange(total):
            output.append(bv.Symbol(8, 'file_{}_{:x}'.format(stream.value, offset)))
            offset += 1
    else:
        with open(entry['path'], 'rb') as host_file:
            host_file.seek(offset, 0)
            for _ in xrange(total):
                byte = host_file.read(1)
                if len(byte) != 1:
                    break  # end of file: short read
                if byte == b'#':
                    # '#' in the host file marks a byte to treat as symbolic
                    output.append(bv.Symbol(8, 'file_{}_{:x}'.format(stream.value, offset)))
                else:
                    output.append(bv.Constant(8, ord(byte)))
                offset += 1

    entry['offset'] = offset

    # fread reports complete elements, not bytes
    bytes_read = offset - start
    items_read = bytes_read // size.value if size.value else 0
    return f.ret(value=bv.Constant(count.size, items_read))
예제 #3
0
    def read(self, address, size):
        """Read a ``size``-bit value from memory at ``address``.

        The address is resolved with ``concretise``. Bytes are assembled from
        the lowest address upwards, each higher-address byte being
        concatenated onto the bytes read so far (``byte.concatenate(prev)``),
        i.e. presumably little-endian.

        With a single candidate address the bytes are loaded directly. With
        several candidates a fresh ``read`` symbol is returned, constrained so
        that it equals the bytes at whichever candidate ``address`` takes.

        Raises:
            ArbitraryRead: if ``arbitrary(self, address)`` holds.
            InvalidRead: if a byte lookup raises ``KeyError``.
        """
        if arbitrary(self, address):
            raise ArbitraryRead(self, address)

        candidates = concretise(self, address)
        byte_count = size // 8

        def load(base, n):
            # assemble n bytes starting at base, higher addresses on top
            loaded = None
            for i in range(n):
                byte = self.memory.read_byte(self, base + i)
                loaded = byte if loaded is None else byte.concatenate(loaded)
            return loaded

        try:
            if len(candidates) <= 1:
                # the single-candidate path always reads at least one byte
                return load(candidates[0].value, max(byte_count, 1))

            value = bv.Symbol(size, unique_name('read'))
            disjunction = None
            for candidate in candidates:
                loaded = load(candidate.value, byte_count)
                clause = (address == candidate) & (value == loaded)
                disjunction = clause if disjunction is None else disjunction | clause
            self.solver.add(disjunction)
            return value
        except KeyError:
            raise InvalidRead(self, address)
예제 #4
0
def sys_fstat(s, cc):
    """Emulate ``fstat`` by filling the caller's stat buffer with symbols.

    The descriptor is only logged. Every call writes 64 fresh 8-bit symbols
    (``sizeof(struct stat)`` for this target, per the original note) starting
    at ``statbuf``, then returns 0.
    """
    f = cc(s)
    fd, statbuf = f.params[0], f.params[1]

    s.log.syscall(f, 'sys_fstat(fd={}, statbuf={})', fd, statbuf)

    stat_size = 64  # sizeof(struct stat) == 64
    out = OutputBuffer(s, statbuf)
    for _ in range(stat_size):
        out.append(bv.Symbol(8, unique_name('fstat')))

    return f.ret(value=0)
예제 #5
0
    def ret(self, value=None):
        """Continue at the current instruction pointer, optionally setting ``rax``.

        The branch target is ``self.state.ip`` (as a 64-bit constant). Nothing
        is popped from the stack.

        ``value`` may be:
          * ``None``: ``rax`` is left unchanged;
          * an integer (``int`` or Python 2 ``long``): stored as a 64-bit constant;
          * a ``str``: used as the base name of a fresh 64-bit symbol;
          * anything else: stored in ``rax`` unchanged.

        Returns:
            whatever ``self.state.branch`` returns for the target.
        """
        import numbers

        return_address = bv.Constant(64, self.state.ip)

        # set return value (if set)
        if value is not None:
            # numbers.Integral covers int and Python 2 long alike, so large
            # integers are not stored raw in the register
            if isinstance(value, numbers.Integral):
                self.state.registers['rax'] = bv.Constant(64, value)
            elif isinstance(value, str):
                self.state.registers['rax'] = bv.Symbol(64, unique_name(value))
            else:
                self.state.registers['rax'] = value

        return self.state.branch(return_address)
예제 #6
0
    def ret(self, value=None):
        """Return from a function: pop a 4-byte return address from ``[esp]``.

        ``esp`` is advanced by 4, the 32-bit return address is read from its
        old value, and execution branches there.

        ``value`` may be:
          * ``None``: ``eax`` is left unchanged;
          * an integer (``int`` or Python 2 ``long``): stored as a 32-bit constant;
          * a ``str``: used as the base name of a fresh 32-bit symbol;
          * anything else: stored in ``eax`` unchanged.

        Returns:
            whatever ``self.state.branch`` returns for the return address.
        """
        import numbers

        # load return address, adjust stack pointer
        esp = self.state.registers['esp']
        self.state.registers['esp'] = esp + bv.Constant(32, 4)
        return_address = self.state.read(esp, 32)

        # set return value (if set)
        if value is not None:
            # numbers.Integral covers int and Python 2 long alike (the 64-bit
            # convention checks long explicitly for the same reason)
            if isinstance(value, numbers.Integral):
                self.state.registers['eax'] = bv.Constant(32, value)
            elif isinstance(value, str):
                self.state.registers['eax'] = bv.Symbol(32, unique_name(value))
            else:
                self.state.registers['eax'] = value

        return self.state.branch(return_address)
예제 #7
0
    def ret(self, value=None):
        """Return from a function: pop an 8-byte return address from ``[rsp]``.

        ``rsp`` is advanced by 8, the 64-bit return address is read from its
        old value, and execution branches there.

        ``value`` may be an int/long (stored in ``rax`` as a 64-bit constant),
        a ``str`` (base name of a fresh 64-bit symbol put in ``rax``), any
        other object (stored in ``rax`` unchanged), or ``None`` (``rax`` is
        left alone).
        """
        stack_top = self.state.registers['rsp']
        self.state.registers['rsp'] = stack_top + bv.Constant(64, 8)
        target = self.state.read(stack_top, 64)

        if value is None:
            return self.state.branch(target)

        if isinstance(value, int) or isinstance(value, long):
            rax = bv.Constant(64, value)
        elif isinstance(value, str):
            rax = bv.Symbol(64, unique_name(value))
        else:
            rax = value
        self.state.registers['rax'] = rax

        return self.state.branch(target)
예제 #8
0
def concretise(state, value, count=8):
    """Return up to ``count`` distinct concrete values that ``value`` can take.

    A non-symbolic ``value`` is returned as ``[value]``. A symbolic
    expression that is not a plain ``bv.Symbol`` is first bound to a fresh
    symbol. That equality is added to ``state.solver`` permanently, so
    solver models can report the value by name. The solver is then queried
    repeatedly, excluding every value already found, until ``count`` values
    are collected or no further value is satisfiable.

    Returns:
        a list of values taken from solver models (arbitrary order, since they
        are gathered in a set). It may be empty if ``state.solver`` is
        already unsatisfiable.
    """
    if not value.symbolic:
        return [value]

    values = set()
    constraint = None

    if not isinstance(value, bv.Symbol):
        new_value = bv.Symbol(value.size, unique_name('concretise'))
        constraint = (new_value == value)
        state.solver.add(constraint)
        value = new_value

    # from here on ``value`` is a bv.Symbol

    # TODO: seeding ``values`` with maximum(state, value) and
    # minimum(state, value) (returning early when they coincide) would help
    # find bugs but currently hurts performance too much; re-enable once
    # path-culling heuristics are better.

    just_seeded = False
    while len(values) < count:
        if not state.solver.check(constraint):
            break

        m = state.solver.model(constraint)
        if value.name not in m:
            if just_seeded:
                # mentioning ``value`` did not put it into the model; stop
                # instead of looping forever
                break
            # The solver doesn't know anything about ``value`` yet, so the
            # model omits it. Mention it through a dummy disequality (which
            # also excludes 0xbeefcafe) and retry, keeping every exclusion
            # accumulated so far rather than overwriting it.
            seed = value != 0xbeefcafe
            if constraint is None:
                constraint = seed
            else:
                constraint = bl.BinaryOperation(constraint, bl.BinaryOperator.And,
                                                seed)
            just_seeded = True
            continue

        just_seeded = False
        model_value = m[value.name]
        values.add(model_value)

        # exclude this value from the next query
        if constraint is not None:
            constraint = bl.BinaryOperation(constraint, bl.BinaryOperator.And,
                                            value != model_value)
        else:
            constraint = value != model_value

    return list(values)
예제 #9
0
def sys_fstat64(s, cc):
    """Emulate ``fstat64`` by filling the caller's ``struct stat64`` buffer.

    For a concrete descriptor 1 a fixed 96-byte image is written byte by byte
    (the raw one-character strings are appended to the buffer as-is). For
    every other descriptor, including a symbolic one, 96 fresh 8-bit
    ``fstat64`` symbols are written instead. Always returns 0.

    NOTE(review): the canned bytes appear to be a captured i386
    ``struct stat64`` of a pseudo-terminal (mode 0x2190 = S_IFCHR|0620,
    uid 1000, gid 5, rdev 0x8803); the layout is inferred, not verified.
    """
    f = cc(s)
    fd, statbuf = f.params[0], f.params[1]

    s.log.syscall(f, 'sys_fstat64(fd={}, statbuf={})', fd, statbuf)

    out = OutputBuffer(s, statbuf)
    is_stdout = (not fd.symbolic) and fd.value == 1
    if is_stdout:
        canned_stat64 = '\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x90\x21\x00\x00\x01\x00\x00\x00\xe8\x03\x00\x00\x05\x00\x00\x00\x03\x88\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x50\x5d\x08\x54\x3d\xa7\x22\x1b\x50\x5d\x08\x54\x3d\xa7\x22\x1b\x43\x40\x08\x54\x3d\xa7\x22\x1b\x06\x00\x00\x00\x00\x00\x00\x00'
        for ch in canned_stat64:
            out.append(ch)
    else:
        stat64_size = 96  # sizeof(struct stat64) == 96
        for _ in range(stat64_size):
            out.append(bv.Symbol(8, unique_name('fstat64')))

    return f.ret(value=0)
예제 #10
0
def gets(s, cc):
    """Model ``gets(buf)`` as reading a very long, fully symbolic line.

    Writes 0x100000 fresh 8-bit symbols (names built by ``unique_name`` from
    ``stdin_<i>``) to ``buf``, followed by a NUL terminator. The same symbols,
    followed by a newline (0x0a), are appended to ``s.files[0]['bytes']``
    (presumably the stdin entry). Returns ``buf``.
    """
    f = cc(s)
    buf = f.params[0]

    s.log.function_call(f, 'gets(buf={})', buf)

    out = OutputBuffer(s, buf)

    # TODO: this needs to use the file for stdin instead of this nonsense
    line_length = 0x100000
    for position in xrange(line_length):
        sym = bv.Symbol(8, unique_name('stdin_{0}'.format(position)))
        s.files[0]['bytes'].append(sym)
        out.append(sym)

    # the input stream carries the line ending; the buffer gets the terminator
    s.files[0]['bytes'].append(bv.Constant(8, 0x0a))
    out.append(bv.Constant(8, 0))

    return f.ret(value=buf)
예제 #11
0
def strcmp(s, cc):
    """Model ``strcmp(str1, str2)`` symbolically.

    Walks both strings in lockstep with the ``String`` iterator, which yields
    ``(char, constraint)`` pairs. NOTE(review): ``constraint`` is assumed to
    mean "this character is not the terminator"; confirm against ``String``.
    Iteration stops as soon as either side yields a concretely zero character.

    The 32-bit result is built back to front as nested if-then-else
    expressions. It is 0 when every compared pair is equal, otherwise -1 or 1
    from the first differing pair, according to ``bv``'s ``<``; whether that
    comparison is signed is not visible here. Each intermediate result is
    bound to a fresh ``tmp`` symbol in ``s.solver``, and a symbolic final
    result to a fresh ``strcmp`` symbol. Returns via ``f.ret``.
    """
    f = cc(s)

    str1 = f.params[0]
    str2 = f.params[1]

    s.log.function_call(f, 'strcmp(str1={}, str2={})', str1, str2)

    iter1 = iter(String(s, str1))
    iter2 = iter(String(s, str2))

    first_smaller = bv.Constant(32, -1)
    first_larger = bv.Constant(32, 1)
    zero = bv.Constant(32, 0)

    # (reached, char1, char2) per position, where ``reached`` is true when
    # all earlier positions were equal and non-terminating
    characters = []
    not_terminated = None
    not_already_terminated = bl.Constant(True)
    while True:
        (char1, constraint1) = next(iter1)
        (char2, constraint2) = next(iter2)

        # the comparison continues past this position only if both
        # characters pass their iterator constraint and are equal
        not_terminated = not_already_terminated & constraint1
        not_terminated = not_terminated & constraint2
        not_terminated = not_terminated & (char1 == char2)

        characters.append((not_already_terminated, char1, char2))

        not_already_terminated = not_terminated

        # a concrete NUL on either side bounds the comparison
        if ((not char1.symbolic and char1.value == 0)
            or (not char2.symbolic and char2.value == 0)):
            break

    # fold from the last position back to the first
    characters.reverse()

    result = None
    prev_result = None
    for (not_already_terminated, char1, char2) in characters:
        if result is None:
            # last position: nothing after it to defer to
            result = bv.if_then_else(
                        char1 == char2,
                        zero,
                        bv.if_then_else(
                            char1 < char2,
                            first_smaller,
                            first_larger))
        else:
            # this position decides only if it was reached and differs;
            # otherwise defer to the result of the later positions
            result = bv.if_then_else(
                        not_already_terminated,
                        bv.if_then_else(
                            char1 == char2,
                            prev_result,
                            bv.if_then_else(
                                char1 < char2,
                                first_smaller,
                                first_larger)),
                        prev_result)

        # binding each step to a fresh symbol significantly reduces the
        # memory footprint of the resulting expression
        prev_result = bv.Symbol(32, unique_name('tmp'))
        s.solver.add(prev_result == result)

    if result.symbolic:
        result_symbol = bv.Symbol(32, unique_name('strcmp'))
        s.solver.add(result_symbol == result)
        result = result_symbol

    return f.ret(value=result)
예제 #12
0
def memcmp(s, cc):
    """Model ``memcmp(ptr1, ptr2, num)``.

    Byte pairs are read from ``ptr1 + i`` and ``ptr2 + i`` for as long as
    the solver finds ``num > i`` satisfiable, so a symbolic ``num`` is
    explored up to its largest feasible value. The result has the width of
    ``ptr1``. It is 0 when every compared pair is equal (including when no
    bytes are compared, ``num == 0``), otherwise -1 or 1 from the first
    differing pair, according to ``bv``'s ``<``. Intermediate values are
    bound to fresh ``tmp`` symbols, and a symbolic result to a fresh
    ``memcmp`` symbol.

    NOTE(review): with a symbolic ``num`` the comparison does not condition
    on ``i < num``, so bytes past the actual length still influence the
    result.
    """
    f = cc(s)

    ptr1 = f.params[0]
    ptr2 = f.params[1]
    num = f.params[2]

    s.log.function_call(f, 'memcmp(ptr1={}, ptr2={}, num={})', ptr1, ptr2, num)

    first_smaller = bv.Constant(ptr1.size, -1)
    first_larger = bv.Constant(ptr1.size, 1)
    zero = bv.Constant(ptr1.size, 0)

    # (reached, byte1, byte2) per position; ``reached`` is true when all
    # earlier pairs were equal
    pairs = []
    count = 0
    not_already_terminated = bl.Constant(True)
    while s.solver.check(num > count):
        byte1 = s.read(ptr1 + bv.Constant(ptr1.size, count), 8)
        byte2 = s.read(ptr2 + bv.Constant(ptr2.size, count), 8)

        not_terminated = not_already_terminated & (byte1 == byte2)

        pairs.append((not_already_terminated, byte1, byte2))

        if not_terminated.symbolic:
            # name the running conjunction so it does not grow per byte
            not_already_terminated = bl.Symbol(unique_name('tmp'))
            s.solver.add(not_already_terminated == not_terminated)
        else:
            not_already_terminated = not_terminated

        count += 1

    if not pairs:
        # num == 0: nothing compared, so the buffers compare equal
        return f.ret(value=zero)

    # fold from the last position back to the first
    pairs.reverse()

    result = None
    prev_result = None
    for (reached, byte1, byte2) in pairs:
        if result is None:
            # last position: nothing after it to defer to
            result = bv.if_then_else(
                        byte1 == byte2,
                        zero,
                        bv.if_then_else(
                            byte1 < byte2,
                            first_smaller,
                            first_larger))
        else:
            # this position decides only if reached and the bytes differ
            result = bv.if_then_else(
                        reached,
                        bv.if_then_else(
                            byte1 == byte2,
                            prev_result,
                            bv.if_then_else(
                                byte1 < byte2,
                                first_smaller,
                                first_larger)),
                        prev_result)

        # binding each step to a fresh symbol significantly reduces the
        # memory footprint of the resulting expression
        prev_result = bv.Symbol(ptr1.size, unique_name('tmp'))
        s.solver.add(prev_result == result)

    if result.symbolic:
        result_symbol = bv.Symbol(result.size, unique_name('memcmp'))
        s.solver.add(result_symbol == result)
        result = result_symbol

    return f.ret(value=result)
예제 #13
0
def memchr(s, cc):
    """Model ``memchr(ptr, value, num)``.

    ``value`` is truncated to 8 bits. A symbolic ``num`` is replaced by
    ``maximum(s, num)`` (presumably its largest feasible value) before
    scanning. A symbolic ``ptr`` is concretised to a small set of candidate
    addresses. When there is more than one, each is explored in a forked
    state constrained to that address.

    For each candidate, bytes at ``ptr + i`` are read while the solver finds
    ``num > i`` satisfiable. Scanning stops early at a concrete match, since
    memchr stops at the first matching byte. The result is the address of
    the first byte equal to ``value``, or a null pointer when none matches
    (including ``num == 0``). A symbolic result is bound to a fresh
    ``memchr`` symbol.

    Returns:
        the concatenated successor lists from ``ret`` for every candidate.
    """
    f = cc(s)

    ptr = f.params[0]
    value = f.params[1].resize(8)
    num = f.params[2]

    if num.symbolic:
        num = maximum(s, num)

    s.log.function_call(f, 'memchr(ptr={}, value={}, num={})', ptr, value, num)

    if ptr.symbolic:
        candidates = concretise(s, ptr)
    else:
        candidates = [ptr]

    successors = []
    forking = len(candidates) > 1
    while candidates:
        base = candidates.pop()

        if forking:
            s_ = s.fork()
            # pin this fork to the concrete address it explores
            s_.solver.add(ptr == base)
        else:
            s_ = s

        null = bv.Constant(base.size, 0)

        # (reached, byte, index) per position; ``reached`` is true when no
        # earlier byte matched ``value``
        scanned = []
        count = 0
        not_already_terminated = bl.Constant(True)
        while s_.solver.check(num > count):
            byte = s_.read(base + bv.Constant(base.size, count), 8)

            # scanning continues past this byte only if it does NOT match
            not_terminated = not_already_terminated & (byte != value)
            scanned.append((not_already_terminated, byte, count))

            if not_terminated.symbolic:
                # name the running conjunction so it does not grow per byte
                not_already_terminated = bl.Symbol(unique_name('tmp'))
                s_.solver.add(not_already_terminated == not_terminated)
            else:
                not_already_terminated = not_terminated

            count += 1

            # a concrete match makes every later byte irrelevant; don't read
            # (possibly unmapped) memory beyond it
            if (not byte.symbolic and not value.symbolic
                    and byte.value == value.value):
                break

        # fold from the last position back to the first
        scanned.reverse()

        result = None
        prev_result = None
        for (reached, byte, index) in scanned:
            if result is None:
                result = bv.if_then_else(
                            byte == value,
                            base + bv.Constant(base.size, index),
                            null)
            else:
                result = bv.if_then_else(
                            reached,
                            bv.if_then_else(
                                byte == value,
                                base + bv.Constant(base.size, index),
                                prev_result),
                            prev_result)

            # binding each step to a fresh symbol significantly reduces the
            # memory footprint of the resulting expression
            prev_result = bv.Symbol(base.size, unique_name('tmp'))
            s_.solver.add(prev_result == result)

        if result is None:
            # num == 0: nothing scanned, no match
            result = null
        elif result.symbolic:
            result_symbol = bv.Symbol(base.size, unique_name('memchr'))
            s_.solver.add(result_symbol == result)
            result = result_symbol

        f_ = cc(s_)
        successors += f_.ret(value=result)

    return successors
예제 #14
0
파일: solver.py 프로젝트: shadown/smt
 def concretise(self):
     """Fix every symbol to its value in the current model.

     Takes ``self.model()``, resets ``self._roots`` to an empty list, then
     calls ``self.add`` with one ``symbol == value`` equality per model entry.
     Each entry is rebuilt as a ``bv.Symbol`` of the value's width, so every
     model entry is treated as a bit-vector.

     NOTE(review): presumably ``_roots`` holds the solver's root constraints,
     so this replaces them with the model equalities; confirm against ``add``.
     """
     m = self.model()
     # drop the existing roots before re-adding the model as equalities
     self._roots = []
     for symbol_name in m:
         symbol_value = m[symbol_name]
         self.add(bv.Symbol(symbol_value.size, symbol_name) == symbol_value)