Beispiel #1
0
def test_detect_syscall_wrapper():
    """Check that detect_syscall_wrapper finds exactly the known wrappers.

    For each test binary, every function of the backend's CFG is classified
    with ``cfg_utils.detect_syscall_wrapper``; the set of
    ``(function address, detected value)`` pairs for the functions it
    recognises must equal the hand-verified expectation.
    """
    expected_by_binary = [
        ("CROMU_00071", {(0x804d483, 1), (0x804d491, 2),
                         (0x804d4b1, 3), (0x804d4d1, 4),
                         (0x804d4f7, 5), (0x804d511, 6),
                         (0x804d525, 7)}),
        ("CROMU_00070", {(0x804d61c, 1), (0x804d62a, 2),
                         (0x804d64a, 3), (0x804d66a, 4),
                         (0x804d690, 5), (0x804d6aa, 6),
                         (0x804d6be, 7)}),
    ]

    for binary_name, legitimate_syscall_wrappers in expected_by_binary:
        backend = DetourBackend(os.path.join(bin_location, binary_name))
        cfg = backend.cfg

        # classify each function exactly once and reuse the result
        syscall_wrappers = set()
        for ff in cfg.functions.values():
            syscall_number = cfg_utils.detect_syscall_wrapper(backend, ff)
            if syscall_number is not None:
                syscall_wrappers.add((ff.addr, syscall_number))

        print("syscall wrappers in %s:" % binary_name)
        print([(hex(addr), number) for addr, number in syscall_wrappers])
        nose.tools.assert_equal(syscall_wrappers, legitimate_syscall_wrappers)
    def function_to_patch_locations(self, ff):
        """Return the patchable entry address and return addresses of ``ff``.

        ``ff`` is rejected when ``cfg_utils.is_sane_function`` fails, when it is
        detected as a syscall wrapper, when it is a floating-point function, or
        when its address is in ``self.safe_functions``. longjmp/setjmp
        implementations are recorded in ``self.found_longjmp`` /
        ``self.found_setjmp`` and are not patched either.

        :param ff: function object from the patcher's CFG.
        :return: ``(start_addr, ret_addrs)`` where ``ret_addrs`` is the ``map``
                 of the ``ret`` instruction addresses, when every return site
                 ends with ``ret``; otherwise ``(None, None)``.
        """
        # TODO tail-call is handled lazily just by considering jumping out functions as not sane
        eligible = (cfg_utils.is_sane_function(ff)
                    and cfg_utils.detect_syscall_wrapper(self.patcher, ff) is None
                    and not cfg_utils.is_floatingpoint_function(self.patcher, ff)
                    and ff.addr not in self.safe_functions)

        if eligible:
            if cfg_utils.is_longjmp(self.patcher, ff):
                self.found_longjmp = ff.addr
            elif cfg_utils.is_setjmp(self.patcher, ff):
                self.found_setjmp = ff.addr
            else:
                entry = ff.startpoint
                ret_addrs = set()
                every_site_returns = True
                for site in ff.ret_sites:
                    block = self.patcher.project.factory.fresh_block(
                        site.addr, site.size)
                    final_insn = block.capstone.insns[-1]
                    if final_insn.mnemonic != u"ret":
                        l.debug(
                            "bb at %s does not terminate with a ret in function %s"
                            % (hex(int(block.addr)), ff.name))
                        every_site_returns = False
                        break
                    ret_addrs.add(final_insn.address)

                if every_site_returns:
                    if ret_addrs:
                        # int() avoids "long" problems with the addresses
                        return int(entry.addr), map(int, ret_addrs)
                    l.debug("cannot find any ret in function %s" % ff.name)

        l.debug("function %s has problems and cannot be patched" % ff.name)
        return None, None
Beispiel #3
0
    def contains_executable_allocation(self, cfg):
        """Return True if some call to an allocate wrapper requests executable memory.

        Every function detected as a syscall wrapper for syscall number 5 is
        considered (not just the first one found). Each of its caller blocks
        (from ``self.inv_callsites``) is stepped once from a fastpath entry
        state; when exactly one successor exists, the dword at ``esp + 8`` in
        that successor (presumably the wrapper's "is executable" argument --
        confirm against the allocate calling convention) is read, and a
        concrete value of 1 counts as an executable allocation.

        :param cfg: the patcher's CFG.
        :return: True as soon as an executable allocation is found, False
                 otherwise (including when no allocate wrapper exists).
        """
        allocate_wrapper_addrs = [
            ff.addr for ff in cfg.functions.values()
            if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 5]
        if not allocate_wrapper_addrs:
            return False

        for allocate_wrapper_addr in allocate_wrapper_addrs:
            # a wrapper without callers may have no entry; don't KeyError on it
            allocate_callers = self.inv_callsites.get(allocate_wrapper_addr, ())
            for bb_addr in allocate_callers:
                state = self.patcher.project.factory.entry_state(mode="fastpath",
                                                                 addr=bb_addr)
                successors = state.step()
                all_succ = (successors.successors +
                            successors.unconstrained_successors)
                if len(all_succ) != 1:
                    # ambiguous or dead path: nothing reliable to inspect
                    continue
                succ = all_succ[0]
                isx_flag = succ.state.mem[succ.state.regs.esp + 8].dword.resolved
                if not isx_flag.symbolic:
                    if succ.state.solver.eval(isx_flag) == 1:
                        l.warning("found executable allocation, at %#8x" % bb_addr)
                        return True
        return False
Beispiel #4
0
    def get_patches(self):
        """Build the indirect-CFI patch list.

        The common patches are always included. Indirect CALL/JUMP
        instructions (single, non-immediate operand) are collected from every
        function that is not a syscall, has a start point and endpoints, is
        not a syscall wrapper or floating-point function, and is not in
        ``self.safe_addrs``. Those classified as "standard" are patched via
        ``handle_standard_cj``; other classifications are not handled.

        Side effects: sets ``self.allocate_executable`` and ``self.safe_addrs``.

        :return: list of patches.
        """
        patches = []
        patches.extend(self.get_common_patches())
        cfg = self.patcher.cfg

        if self.contains_executable_allocation(cfg):
            l.warning(
                "found executable allocation, I will not apply indirect CFI")
            self.allocate_executable = True
        else:
            self.allocate_executable = False

        self.safe_addrs = self.get_safe_functions()

        # the overlapping instruction issue seems to be fixed, at least partially
        # I am still using a dict and raising warnings in case of problems.
        sci = {}  # instruction address -> (instruction, block, function)
        for ff in cfg.functions.values():
            if (ff.is_syscall or ff.startpoint is None or ff.endpoints is None
                    or cfg_utils.detect_syscall_wrapper(self.patcher, ff) is not None
                    or cfg_utils.is_floatingpoint_function(self.patcher, ff)
                    or ff.addr in self.safe_addrs):
                continue
            for bb in ff.blocks:
                for ci in bb.capstone.insns:
                    if not (ci.group(capstone.x86_const.X86_GRP_CALL)
                            or ci.group(capstone.x86_const.X86_GRP_JUMP)):
                        continue
                    if len(ci.operands) != 1:
                        l.warning(
                            "Unexpected operand size for CALL/JUMP: %s"
                            % str(ci))
                        continue
                    op = ci.operands[0]
                    if op.type == capstone.x86_const.X86_OP_IMM:
                        # direct transfer: nothing to protect
                        continue
                    if ci.address in sci:
                        old_ci = sci[ci.address]
                        tstr = "instruction at %08x (bb: %08x, function %08x) " % \
                                (ci.address,bb.addr,ff.addr)
                        tstr += "previously found at bb: %08x in function: %08x" % \
                                (old_ci[1].addr,old_ci[2].addr)
                        l.warning(tstr)
                    else:
                        sci[ci.address] = (ci, bb, ff)

        for instruction, _, ff in sci.values():
            l.info("Found indirect CALL/JUMP: %s" % str(instruction))
            cj_type = self.classify_cj(instruction)
            if cj_type != "standard":
                # only "standard" CALL/JUMPs are handled; never reuse patches
                # computed for a different instruction
                continue
            try:
                new_patches = self.handle_standard_cj(instruction, ff)
            except utils.NasmException:
                l.warning(
                    "NASM exception while compiling mem_access for %s" %
                    instruction)
                continue
            patches.extend(new_patches)

        return patches
Beispiel #5
0
    def get_patches(self):
        """Move the stack into a new non-executable segment, unless that looks unsafe.

        The transformation is abandoned (``[]`` is returned) when any function
        that is neither a syscall nor a SimProcedure
          * has a call site whose successor lies in the original stack range, or
          * contains a VEX constant inside the original stack range.
        Otherwise a segment-header patch and an entry-point patch that shifts
        ``esp`` by 0x1000 are returned.

        :return: list of patches, or ``[]`` when the binary looks stack-dependent.
        """
        # original stack range; max stack size is 8MB
        stack_low = 0xba2aa000
        stack_high = 0xbaaab000

        cfg = self.patcher.cfg
        for ff in cfg.functions.values():

            if ff.is_syscall or ff.is_simprocedure:
                # don't patch syscalls or SimProcedures
                continue

            if ff.startpoint is not None and ff.endpoints is not None and \
                    cfg_utils.detect_syscall_wrapper(self.patcher, ff) is None and \
                    not cfg_utils.is_floatingpoint_function(self.patcher, ff):
                for cs in ff.get_call_sites():
                    nn = ff.get_node(cs)
                    if any(stack_low <= n.addr < stack_high
                           for n in nn.successors()):
                        l.warning("found call to stack at %#8x, avoiding nx" %
                                  nn.addr)
                        return []

            for block in ff.blocks:
                for s in block.vex.statements:
                    # NOTE: the upper bound is inclusive here (a constant equal
                    # to the stack top also counts), unlike the call check above
                    if any(stack_low <= v.value <= stack_high
                           for v in s.constants):
                        l.warning(
                            "found constant that looks stack-related at %#8x, avoiding nx"
                            % block.addr)
                        return []

        patches = []
        # NOTE(review): presumably ELF program-header fields (PT_LOAD, RW,
        # 0x1000 bytes at 0xbaaab000) -- confirm against AddSegmentHeaderPatch
        nxsegment_after_stack = (0x1, 0x0, 0xbaaab000, 0xbaaab000, 0x0, 0x1000,
                                 0x6, 0x1000)
        patches.append(
            AddSegmentHeaderPatch(nxsegment_after_stack,
                                  name="nxstack_segment_header"))
        added_code = '''
            ; int 3
            ; this can be placed before or after the stack shift
            add esp, 0x1000
            ; restore flags, assume eax=0 since we are after restore
            push 0x202
            popf
            mov DWORD [esp-4], eax
        '''
        patches.append(
            AddEntryPointPatch(added_code,
                               name="move_stack_to_nx",
                               after_restore=True))

        return patches
    def _block_calls_safe_syscalls(self, block, func_info, var):
        """Tell whether ``block``'s single call target is safe w.r.t. stack var ``var``.

        :param block: block whose constant jump targets are inspected.
        :param func_info: info of the enclosing function (``bp_based``,
                          ``bp_sp_diff``, ``frame_size`` are read).
        :param var: stack-variable offset to check against receive's rx_bytes
                    argument.
        :return: True if the only target is a syscall wrapper other than
                 receive (3), or a receive wrapper whose concrete rx_bytes
                 argument resolves to ``var``; False otherwise.
        """
        # checks that the block only calls syscalls that aren't receive
        # receive is the only one that stack ret encryption is useful for
        # also only checks this var

        target_kind = block.vex.constant_jump_targets_and_jumpkinds
        if len(target_kind) != 1:
            return False
        # the single target (portable alternative to keys()[0])
        target_addr = next(iter(target_kind))
        if target_addr not in self.patcher.cfg.functions:
            return False
        target = self.patcher.cfg.functions[target_addr]

        # classify the target once
        syscall_number = cfg_utils.detect_syscall_wrapper(self.patcher, target)
        if not syscall_number:
            return False
        if syscall_number != 3:
            return True

        # it is receive: we need to do extra checks
        # execute the block
        s = self.patcher.identifier.base_symbolic_state.copy()
        s.regs.ip = block.addr
        if func_info.bp_based:
            s.regs.bp = s.regs.sp + func_info.bp_sp_diff
        simgr = self.patcher.project.factory.simulation_manager(
            s, save_unconstrained=True)
        simgr.step()
        candidates = simgr.active + simgr.unconstrained
        if not candidates:
            return False
        succ = candidates[0]
        # dword at sp + 16 after the call: the rx_bytes argument of receive
        rx_arg = succ.mem[succ.regs.sp + 16].dword.resolved
        if rx_arg.symbolic:
            return False
        # we say the rx_bytes arg is okay when it points exactly at var
        if func_info.bp_based:
            rx_bytes_off = 0 - succ.solver.eval(s.regs.bp - rx_arg)
        else:
            rx_bytes_off = succ.solver.eval(rx_arg - s.regs.sp) - (
                func_info.frame_size + 4) + 4
        return rx_bytes_off == var
Beispiel #7
0
    def _should_skip(self, ff):
        """Tell whether function ``ff`` must be left unpatched.

        Skipped functions: detected syscall wrappers, syscalls, functions
        without a start point, floating-point functions, and functions whose
        entry node has more than five distinct predecessor addresses in the CFG.
        """
        if (cfg_utils.detect_syscall_wrapper(self.patcher, ff)
                or ff.is_syscall or ff.startpoint is None):
            return True
        if cfg_utils.is_floatingpoint_function(self.patcher, ff):
            return True

        cfg = self.patcher.cfg
        entry_node = cfg.get_any_node(ff.addr)
        predecessor_addrs = {pred.addr for pred in cfg.get_predecessors(entry_node)}
        # heavily-referenced functions are left alone
        return len(predecessor_addrs) > 5
Beispiel #8
0
    def get_patches(self):
        """Add a bit-flip step after the receive syscall of the receive wrapper.

        Requires exactly one function detected as a receive (3) syscall
        wrapper; otherwise nothing is patched. The victim address is the first
        block of the wrapper other than its entry block. The inserted code runs
        the bit-flip code only when ``eax`` is 0 and ``ebx`` is 0 or 1.

        :return: list of patches, or ``[]`` when the wrapper cannot be located.
        """
        patches = []
        cfg = self.patcher.cfg

        receive_wrapper = [ff for ff in cfg.functions.values()
                           if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 3]
        if len(receive_wrapper) != 1:
            l.warning(
                "Found %d receive_wrapper... better not to touch anything" %
                len(receive_wrapper))
            return []
        receive_wrapper = receive_wrapper[0]
        # here we assume that receive_wrapper is a "sane" syscall wrapper, as checked by detect_syscall_wrapper
        last_block = next((b for b in receive_wrapper.blocks
                           if b.addr != receive_wrapper.addr), None)
        if last_block is None:
            l.warning(
                "receive_wrapper at %#x has no block after its entry... better not to touch anything"
                % receive_wrapper.addr)
            return []
        victim_addr = int(last_block.addr)
        # presumably the 2-byte "int 0x80" ending the previous block -- confirm
        syscall_addr = victim_addr - 2

        patches.extend(Bitflip.get_presyscall_patch(syscall_addr))
        patches.append(Bitflip.get_translation_table_patch())
        # free registers esi, edx, ecx, ebx are free because we are in a syscall wrapper restoring them
        # ebx: fd, ecx: buf, edx: count, esi: rx_byte
        code = '''
            test eax, eax ; receive succeded
            jne _exit_bitflip

            test ebx, ebx ; test if ebx is 0 (stdin)
            je _enter_bitflip
            cmp ebx, 1
            jne _exit_bitflip
            _enter_bitflip:

            %s

            _exit_bitflip:
        ''' % (Bitflip.get_bitflip_code())

        patches.append(
            InsertCodePatch(victim_addr,
                            code,
                            "postreceive_bitflip_patch",
                            priority=900))
        return patches
    def get_patches(self):
        """Patch the last instruction of the transmit wrapper's entry node.

        Requires exactly one function detected as a transmit (2) syscall
        wrapper and a CFG node at its address; otherwise nothing is patched.

        :return: patches from ``self.compute_patches(victim_addr)``, or ``[]``.
        """
        patches = []
        cfg = self.patcher.cfg

        transmit_wrapper = [ff for ff in cfg.functions.values()
                            if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 2]
        if len(transmit_wrapper) != 1:
            l.warning("Found %d transmit_wrapper... better not to touch anything"%len(transmit_wrapper))
            return []
        transmit_wrapper = transmit_wrapper[0]
        victim_node = cfg.get_any_node(transmit_wrapper.addr)
        if victim_node is None or not victim_node.instruction_addrs:
            l.warning(
                "cannot find the CFG node of transmit_wrapper at %#x... better not to touch anything"
                % transmit_wrapper.addr)
            return []
        victim_addr = int(victim_node.instruction_addrs[-1])

        patches.extend(self.compute_patches(victim_addr))

        return patches
Beispiel #10
0
    def get_patches(self):
        """Patch the block that follows the receive syscall in the receive wrapper.

        Requires exactly one function detected as a receive (3) syscall wrapper
        with at least one block besides its entry; otherwise nothing is
        patched. When ``self.enable_bitflip`` is set, the bit-flip pre-syscall
        and translation-table patches are added as well.

        :return: list of patches, or ``[]`` when the wrapper cannot be located.
        """
        patches = []
        cfg = self.patcher.cfg

        receive_wrapper = [ff for ff in cfg.functions.values()
                           if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 3]
        if len(receive_wrapper) != 1:
            l.warning(
                "Found %d receive_wrapper... better not to touch anything" %
                len(receive_wrapper))
            return []
        receive_wrapper = receive_wrapper[0]
        # here we assume that receive_wrapper is a "sane" syscall wrapper, as checked by detect_syscall_wrapper
        last_block = next((b for b in receive_wrapper.blocks
                           if b.addr != receive_wrapper.addr), None)
        if last_block is None:
            l.warning(
                "receive_wrapper at %#x has no block after its entry... better not to touch anything"
                % receive_wrapper.addr)
            return []
        victim_addr = int(last_block.addr)
        patches.extend(self.compute_patches(victim_addr))
        if self.enable_bitflip:
            # presumably the 2-byte "int 0x80" right before the victim block -- confirm
            patches.extend(Bitflip.get_presyscall_patch(victim_addr - 2))
            patches.append(Bitflip.get_translation_table_patch())

        return patches
Beispiel #11
0
    def _handle_func(self, ff):
        """Zero possibly-uninitialized stack variables of ``ff`` at its entry.

        The function's blocks are walked path by path; each path keeps its own
        copy of the blocks seen and of the stack offsets written so far. Only
        negative, 4-aligned offsets are tracked:
          * a "read" of an unwritten offset is a definite uninitialized read;
          * a "load" of an unwritten offset marks the range up to the next
            written offset (only if that range is < 0x40 bytes) as possibly
            uninitialized -- unless the block starting at that instruction has a
            single constant target that is a syscall wrapper other than
            transmit (2), in which case the offset just counts as written.

        The collected offsets are zeroed with ``mov DWORD`` stores through a
        stackable ``InsertCodePatch`` at ``ff.addr`` appended to
        ``self.patches``. Nothing is emitted when ``_should_skip`` says so,
        when there is no func_info, when the frame is larger than 0x2000
        bytes, when nothing needs zeroing, when a resulting offset is
        non-negative, or when grouping produces an unexpected offset.
        """
        if self._should_skip(ff):
            return

        func_info = self.ident.get_func_info(ff.addr)
        if func_info is None:
            return
        if func_info.frame_size > 0x2000:
            return

        inverted_stack_accesses = self._invert_stack_var_accesses(func_info)

        def_uninitialized_reads = set()
        possible_uninitialized_reads = set()

        to_process = [(ff.startpoint, set(), set())
                      ]  # block, seen addrs, written to stack vars
        while to_process:
            bl, seen, written = to_process.pop()
            seen.add(bl)

            cfg_node = self.patcher.cfg.get_any_node(bl.addr)
            if not cfg_node:
                continue
            insts = cfg_node.instruction_addrs

            for i in insts:
                if i in inverted_stack_accesses:
                    actions = inverted_stack_accesses[i]
                    for arg, action in actions:
                        if arg >= 0:
                            continue

                        if action == "write":
                            written.add(arg)
                        elif action == "read" and arg not in written and arg % 4 == 0:
                            def_uninitialized_reads.add(arg)
                        elif action == "load" and arg not in written and arg % 4 == 0:
                            # get the min > arg
                            subset_written = set(a for a in written if a > arg)
                            if len(subset_written) == 0:
                                the_next = 0
                            else:
                                the_next = min(subset_written)

                            uninitialized_size = the_next - arg

                            target_kind = self.patcher.project.factory.block(
                                i).vex.constant_jump_targets_and_jumpkinds
                            call_target = None
                            if len(target_kind) == 1:
                                # the single target (portable alternative to keys()[0])
                                target_addr = next(iter(target_kind))
                                if target_addr in self.patcher.cfg.functions:
                                    call_target = self.patcher.cfg.functions[
                                        target_addr]

                            # classify the call target once
                            syscall_number = None
                            if call_target is not None:
                                syscall_number = cfg_utils.detect_syscall_wrapper(
                                    self.patcher, call_target)

                            # if the target is a syscall wrapper (not transmit) it's safe
                            if call_target is None or syscall_number == 2 or \
                                    not syscall_number:

                                if uninitialized_size < 0x40:
                                    possible_uninitialized_reads.update(
                                        arg + x for x in range(
                                            0, uninitialized_size, 4))
                                    written.update(arg + x for x in range(
                                        0, uninitialized_size, 4))
                            else:
                                written.add(arg)

            succs = ff.graph.successors(bl)
            for s in succs:
                if s not in seen:
                    seen.add(s)
                    to_process.append((s, set(seen), set(written)))

        def_uninitialized_reads = sorted(def_uninitialized_reads)
        possible_uninitialized_reads = sorted(possible_uninitialized_reads)

        if len(def_uninitialized_reads) > 0:
            l.debug("definite uninitialized read by func %#x of vars %s",
                    ff.addr, [hex(x) for x in def_uninitialized_reads])

        if len(possible_uninitialized_reads) > 0:
            l.debug("possible uninitialized read by func %#x of vars %s",
                    ff.addr, [hex(x) for x in possible_uninitialized_reads])

        # highest offset first
        to_zero = sorted(def_uninitialized_reads +
                         possible_uninitialized_reads)[::-1]

        if len(to_zero) == 0:
            return

        # NOTE(review): shift presumably converts analysis offsets to esp at
        # function entry -- confirm against _invert_stack_var_accesses
        to_zero = [x - 4 for x in to_zero]

        free_regs = set()
        for r in self.relevant_registers:
            if self.is_reg_free(ff.addr, r, False):
                free_regs.add(r)
        free_regs = list(free_regs)

        # note that all of to_zero should be < 0
        if any(v >= 0 for v in to_zero):
            return

        patch_name = "uninit_patch%#x" % ff.addr

        if len(to_zero) == 1:
            code = "mov DWORD [esp%#x], 0; " % to_zero[0]
            l.debug("adding:\n%s", code)
            patch = InsertCodePatch(ff.addr, code, patch_name, stackable=True)
            self.patches.append(patch)
            return

        groups = self._make_groups(to_zero)

        prefix = ""
        suffix = ""
        body = ""
        # use a reg as 0
        if len(free_regs) == 0:
            # borrow eax; the push moves esp, so offsets are adjusted by 4
            prefix += "push eax; \n"
            zero_reg = "eax"
            suffix += "pop eax; "
            groups = self._fix_groups(groups, 4)
        else:
            zero_reg = free_regs[0]
        prefix += "xor %s, %s; \n" % (zero_reg, zero_reg)

        offset_reg = "XXX"  # these should not be used
        offset_reg_curr = 0xffff  # these should not be used
        min_group_size = 3
        if any(len(g) >= min_group_size for g in groups):
            # use a register for the offset
            if len(free_regs) > 1:
                offset_reg = free_regs[1]
            else:
                # borrow a register, saved on the stack; it must differ from
                # zero_reg, which may itself be edi (the only free register)
                offset_reg = "esi" if zero_reg == "edi" else "edi"
                prefix += "push %s; \n" % offset_reg
                suffix = "pop %s; \n" % offset_reg + suffix
                groups = self._fix_groups(groups, 4)
                if not any(len(g) >= min_group_size for g in groups):
                    min_group_size = min(len(g) for g in groups)

            first_group_off = next(g[0] for g in groups
                                   if len(g) >= min_group_size)
            prefix += "lea %s, [esp%#x]; \n" % (offset_reg, first_group_off)
            offset_reg_curr = first_group_off

        for g in groups:
            if len(g) < min_group_size:
                # small group: store directly relative to esp
                for off in g:
                    body += "mov DWORD [esp%#x], %s; \n" % (off, zero_reg)
            else:
                # large group: move offset_reg to the group start, store relative to it
                if offset_reg_curr != g[0]:
                    if g[0] - offset_reg_curr > 0:
                        body += "add %s, %#x; \n" % (offset_reg,
                                                     g[0] - offset_reg_curr)
                    else:
                        body += "sub %s, %#x; \n" % (offset_reg,
                                                     offset_reg_curr - g[0])
                offset_reg_curr = g[0]
                for off in g:
                    if off == offset_reg_curr:
                        body += "mov DWORD [%s], %s; \n" % (offset_reg,
                                                            zero_reg)
                    else:
                        if off - offset_reg_curr >= 0:
                            l.debug("bad error, skipping patch")
                            return
                        body += "mov DWORD [%s%#x], %s; \n" % (
                            offset_reg, off - offset_reg_curr, zero_reg)

        code = prefix + body + suffix
        l.debug("adding:\n%s", code)
        self.patches.append(
            InsertCodePatch(ff.addr, code, patch_name, stackable=True))