def test_detect_syscall_wrapper():
    filepath = os.path.join(bin_location, "CROMU_00071")
    backend = DetourBackend(filepath)
    cfg = backend.cfg
    legitimate_syscall_wrappers = set([(0x804d483, 1), (0x804d491, 2), (0x804d4b1, 3), (0x804d4d1, 4),
                                       (0x804d4f7, 5), (0x804d511, 6), (0x804d525, 7)])
    syscall_wrappers = set([(ff.addr, cfg_utils.detect_syscall_wrapper(backend, ff))
                            for ff in cfg.functions.values()
                            if cfg_utils.detect_syscall_wrapper(backend, ff) is not None])
    print "syscall wrappers in CROMU_00071:"
    print map(lambda x: (hex(x[0]), x[1]), syscall_wrappers)
    nose.tools.assert_equal(syscall_wrappers, legitimate_syscall_wrappers)

    filepath = os.path.join(bin_location, "CROMU_00070")
    backend = DetourBackend(filepath)
    cfg = backend.cfg
    legitimate_syscall_wrappers = set([(0x804d61c, 1), (0x804d62a, 2), (0x804d64a, 3), (0x804d66a, 4),
                                       (0x804d690, 5), (0x804d6aa, 6), (0x804d6be, 7)])
    syscall_wrappers = set([(ff.addr, cfg_utils.detect_syscall_wrapper(backend, ff))
                            for ff in cfg.functions.values()
                            if cfg_utils.detect_syscall_wrapper(backend, ff) is not None])
    print "syscall wrappers in CROMU_00070:"
    print map(lambda x: (hex(x[0]), x[1]), syscall_wrappers)
    nose.tools.assert_equal(syscall_wrappers, legitimate_syscall_wrappers)
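# The integer paired with each wrapper address above is the CGC syscall number
# that detect_syscall_wrapper recognized. As a reference, a minimal sketch of
# that numbering (this dict and its name are ours, for illustration only; the
# values follow the DARPA CGC ABI and match the checks used further below,
# e.g. 2 = transmit, 3 = receive, 5 = allocate):
CGC_SYSCALL_NUMBERS = {
    1: "terminate",
    2: "transmit",
    3: "receive",
    4: "fdwait",
    5: "allocate",
    6: "deallocate",
    7: "random",
}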
def function_to_patch_locations(self, ff):
    # TODO tail-call is handled lazily just by considering jumping-out functions as not sane
    if cfg_utils.is_sane_function(ff) and cfg_utils.detect_syscall_wrapper(self.patcher, ff) is None \
            and not cfg_utils.is_floatingpoint_function(self.patcher, ff) and ff.addr not in self.safe_functions:
        if cfg_utils.is_longjmp(self.patcher, ff):
            self.found_longjmp = ff.addr
        elif cfg_utils.is_setjmp(self.patcher, ff):
            self.found_setjmp = ff.addr
        else:
            start = ff.startpoint
            ends = set()
            for ret_site in ff.ret_sites:
                bb = self.patcher.project.factory.fresh_block(ret_site.addr, ret_site.size)
                last_instruction = bb.capstone.insns[-1]
                if last_instruction.mnemonic != u"ret":
                    msg = "bb at %s does not terminate with a ret in function %s"
                    l.debug(msg % (hex(int(bb.addr)), ff.name))
                    break
                else:
                    ends.add(last_instruction.address)
            else:
                # for/else: this branch runs only if no break occurred,
                # i.e. every return site really ends with a ret
                if len(ends) == 0:
                    l.debug("cannot find any ret in function %s" % ff.name)
                else:
                    return int(start.addr), map(int, ends)  # avoid "long" problems
    l.debug("function %s has problems and cannot be patched" % ff.name)
    return None, None
def contains_executable_allocation(self, cfg):
    allocate_wrapper_addr = None
    for ff in cfg.functions.values():
        if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 5:
            allocate_wrapper_addr = ff.addr
            break
    if allocate_wrapper_addr is None:
        return False

    allocate_callers = self.inv_callsites[allocate_wrapper_addr]
    for bb_addr in allocate_callers:
        state = self.patcher.project.factory.entry_state(mode="fastpath", addr=bb_addr)
        successors = state.step()
        all_succ = successors.successors + successors.unconstrained_successors
        if len(all_succ) != 1:
            continue
        succ = all_succ[0]
        # successors are SimStates, so read the is_X argument directly off the stack
        isx_flag = succ.mem[succ.regs.esp + 8].dword.resolved
        if not isx_flag.symbolic:
            if succ.solver.eval(isx_flag) == 1:
                l.warning("found executable allocation, at %#8x" % bb_addr)
                return True
    return False
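# Why [esp+8]: CGC's allocate has the cdecl signature
#   int allocate(size_t length, int is_X, void **addr);
# After single-stepping the caller's block, execution sits at the wrapper entry
# with the return address on top of the stack, so the call-site layout is
# (a sketch, assuming the standard cdecl convention):
#   [esp]     return address
#   [esp+4]   length
#   [esp+8]   is_X   <- 1 requests executable memory, which defeats indirect CFI
#   [esp+12]  addr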
def get_patches(self):
    patches = []
    patches.extend(self.get_common_patches())
    cfg = self.patcher.cfg

    if self.contains_executable_allocation(cfg):
        l.warning("found executable allocation, I will not apply indirect CFI")
        self.allocate_executable = True
    else:
        self.allocate_executable = False

    self.safe_addrs = self.get_safe_functions()

    # the overlapping instruction issue seems to be fixed, at least partially
    # I am still using a dict and raising warnings in case of problems.
    sci = {}
    for ff in cfg.functions.values():
        if not ff.is_syscall and ff.startpoint is not None and ff.endpoints is not None and \
                cfg_utils.detect_syscall_wrapper(self.patcher, ff) is None and \
                not cfg_utils.is_floatingpoint_function(self.patcher, ff) and \
                ff.addr not in self.safe_addrs:
            for bb in ff.blocks:
                for ci in bb.capstone.insns:
                    if ci.group(capstone.x86_const.X86_GRP_CALL) or ci.group(capstone.x86_const.X86_GRP_JUMP):
                        if len(ci.operands) != 1:
                            l.warning("Unexpected number of operands for CALL/JUMP: %s" % str(ci))
                        else:
                            op = ci.operands[0]
                            if op.type != capstone.x86_const.X86_OP_IMM:
                                if ci.address in sci:
                                    old_ci = sci[ci.address]
                                    tstr = "instruction at %08x (bb: %08x, function %08x) " % \
                                           (ci.address, bb.addr, ff.addr)
                                    tstr += "previously found at bb: %08x in function: %08x" % \
                                            (old_ci[1].addr, old_ci[2].addr)
                                    l.warning(tstr)
                                else:
                                    sci[ci.address] = (ci, bb, ff)

    for instruction, bb, ff in sci.values():
        l.info("Found indirect CALL/JUMP: %s" % str(instruction))
        cj_type = self.classify_cj(instruction)
        if cj_type == "standard":
            try:
                new_patches = self.handle_standard_cj(instruction, ff)
            except utils.NasmException:
                l.warning("NASM exception while compiling mem_access for %s" % instruction)
                continue
            patches.extend(new_patches)
    return patches
def get_patches(self):
    cfg = self.patcher.cfg
    for k, ff in cfg.functions.items():
        if ff.is_syscall or ff.is_simprocedure:
            # don't patch syscalls or SimProcedures
            continue
        if ff.startpoint is not None and ff.endpoints is not None and \
                cfg_utils.detect_syscall_wrapper(self.patcher, ff) is None and \
                not cfg_utils.is_floatingpoint_function(self.patcher, ff):
            call_sites = ff.get_call_sites()
            for cs in call_sites:
                nn = ff.get_node(cs)
                # max stack size is 8MB
                if any([0xba2aa000 <= n.addr < 0xbaaab000 for n in nn.successors()]):
                    l.warning("found call to stack at %#8x, avoiding nx" % nn.addr)
                    return []
            for block in ff.blocks:
                for s in block.vex.statements:
                    if any([0xba2aa000 <= v.value <= 0xbaaab000 for v in s.constants]):
                        l.warning("found constant that looks stack-related at %#8x, avoiding nx" % block.addr)
                        return []

    patches = []
    nxsegment_after_stack = (0x1, 0x0, 0xbaaab000, 0xbaaab000, 0x0, 0x1000, 0x6, 0x1000)
    patches.append(AddSegmentHeaderPatch(nxsegment_after_stack, name="nxstack_segment_header"))
    added_code = '''
        ; int 3 ; this can be placed before or after the stack shift
        add esp, 0x1000
        ; restore flags, assume eax=0 since we are after restore
        push 0x202
        popf
        mov DWORD [esp-4], eax
    '''
    patches.append(AddEntryPointPatch(added_code, name="move_stack_to_nx", after_restore=True))
    return patches
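# The nxsegment_after_stack tuple follows the Elf32_Phdr field order consumed
# by AddSegmentHeaderPatch. A hedged decoding (field names from the ELF spec;
# the mapping of tuple positions is our reading, not documented in the repo):
#   p_type   = 0x1         PT_LOAD
#   p_offset = 0x0
#   p_vaddr  = 0xbaaab000  one page just above the 8MB CGC stack region
#   p_paddr  = 0xbaaab000
#   p_filesz = 0x0
#   p_memsz  = 0x1000
#   p_flags  = 0x6         PF_R | PF_W, no PF_X: readable/writable, not executable
#   p_align  = 0x1000
# The entry-point patch then shifts esp up by 0x1000 so the live stack top
# lands inside this non-executable page.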
def _block_calls_safe_syscalls(self, block, func_info, var):
    # checks that the block only calls syscalls that aren't receive:
    # receive is the only one that stack ret encryption is useful for
    # also only checks this var
    target_kind = block.vex.constant_jump_targets_and_jumpkinds
    if len(target_kind) != 1:
        return False
    if target_kind.keys()[0] not in self.patcher.cfg.functions:
        return False
    target = self.patcher.cfg.functions[target_kind.keys()[0]]
    wrapper = cfg_utils.detect_syscall_wrapper(self.patcher, target)
    if wrapper and wrapper != 3:
        return True
    # if it is receive we need to do extra checks
    if wrapper == 3:
        # execute the block
        s = self.patcher.identifier.base_symbolic_state.copy()
        s.regs.ip = block.addr
        if func_info.bp_based:
            s.regs.bp = s.regs.sp + func_info.bp_sp_diff
        simgr = self.patcher.project.factory.simulation_manager(s, save_unconstrained=True)
        simgr.step()
        if len(simgr.active + simgr.unconstrained) > 0:
            succ = (simgr.active + simgr.unconstrained)[0]
            rx_arg = succ.mem[succ.regs.sp + 16].dword.resolved
            size_arg = succ.mem[succ.regs.sp + 12].dword.resolved
            # we say the rx_bytes arg is okay
            if not rx_arg.symbolic:
                if func_info.bp_based:
                    rx_bytes_off = 0 - succ.se.eval(s.regs.bp - rx_arg)
                else:
                    rx_bytes_off = succ.se.eval(rx_arg - s.regs.sp) - (func_info.frame_size + 4) + 4
                if rx_bytes_off == var:
                    return True
    return False
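# Why sp+16 and sp+12: CGC's receive has the cdecl signature
#   int receive(int fd, void *buf, size_t count, size_t *rx_bytes);
# so right after stepping into the wrapper (return address at [sp]) the
# arguments sit at [sp+4]=fd, [sp+8]=buf, [sp+12]=count, [sp+16]=rx_bytes.
# This is our reading of the call-site layout, not something the repo states.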
def _should_skip(self, ff):
    if cfg_utils.detect_syscall_wrapper(self.patcher, ff) or ff.is_syscall or ff.startpoint is None:
        return True
    if cfg_utils.is_floatingpoint_function(self.patcher, ff):
        return True
    all_pred_addrs = set(x.addr for x in self.patcher.cfg.get_predecessors(
        self.patcher.cfg.get_any_node(ff.addr)))
    if len(all_pred_addrs) > 5:
        return True
    return False
def get_patches(self):
    patches = []
    cfg = self.patcher.cfg
    receive_wrapper = [ff for ff in cfg.functions.values()
                       if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 3]
    if len(receive_wrapper) != 1:
        l.warning("Found %d receive_wrapper... better not to touch anything" % len(receive_wrapper))
        return []
    receive_wrapper = receive_wrapper[0]
    # here we assume that receive_wrapper is a "sane" syscall wrapper, as checked by detect_syscall_wrapper
    last_block = [b for b in receive_wrapper.blocks if b.addr != receive_wrapper.addr][0]
    victim_addr = int(last_block.addr)
    syscall_addr = victim_addr - 2

    patches.extend(Bitflip.get_presyscall_patch(syscall_addr))
    patches.append(Bitflip.get_translation_table_patch())

    # esi, edx, ecx, ebx are free because we are in a syscall wrapper restoring them
    # ebx: fd, ecx: buf, edx: count, esi: rx_byte
    code = '''
        test eax, eax ; receive succeeded
        jne _exit_bitflip

        test ebx, ebx ; test if ebx is 0 (stdin)
        je _enter_bitflip
        cmp ebx, 1
        jne _exit_bitflip

        _enter_bitflip:
        %s
        _exit_bitflip:
    ''' % Bitflip.get_bitflip_code()
    patches.append(InsertCodePatch(victim_addr, code, "postreceive_bitflip_patch", priority=900))
    return patches
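# Why victim_addr - 2 is the syscall: CGC syscalls are issued with int 0x80,
# which encodes as the two bytes cd 80. In a sane wrapper the last block starts
# right after the syscall instruction, so stepping back two bytes lands exactly
# on the int 0x80 itself (our reasoning; the offset is hard-coded in the repo).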
def get_patches(self):
    patches = []
    cfg = self.patcher.cfg
    transmit_wrapper = [ff for ff in cfg.functions.values()
                        if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 2]
    if len(transmit_wrapper) != 1:
        l.warning("Found %d transmit_wrapper... better not to touch anything" % len(transmit_wrapper))
        return []
    transmit_wrapper = transmit_wrapper[0]
    victim_node = cfg.get_any_node(transmit_wrapper.addr)
    victim_addr = int(victim_node.instruction_addrs[-1])
    patches.extend(self.compute_patches(victim_addr))
    return patches
def get_patches(self):
    patches = []
    cfg = self.patcher.cfg
    receive_wrapper = [ff for ff in cfg.functions.values()
                       if cfg_utils.detect_syscall_wrapper(self.patcher, ff) == 3]
    if len(receive_wrapper) != 1:
        l.warning("Found %d receive_wrapper... better not to touch anything" % len(receive_wrapper))
        return []
    receive_wrapper = receive_wrapper[0]
    # here we assume that receive_wrapper is a "sane" syscall wrapper, as checked by detect_syscall_wrapper
    last_block = [b for b in receive_wrapper.blocks if b.addr != receive_wrapper.addr][0]
    victim_addr = int(last_block.addr)
    patches.extend(self.compute_patches(victim_addr))
    if self.enable_bitflip:
        patches.extend(Bitflip.get_presyscall_patch(victim_addr - 2))
        patches.append(Bitflip.get_translation_table_patch())
    return patches
def _handle_func(self, ff):
    if self._should_skip(ff):
        return
    func_info = self.ident.get_func_info(ff.addr)
    if func_info is None:
        return
    if func_info.frame_size > 0x2000:
        return
    inverted_stack_accesses = self._invert_stack_var_accesses(func_info)
    def_uninitialized_reads = set()
    possible_uninitialized_reads = set()
    to_process = [(ff.startpoint, set(), set())]  # block, seen addrs, written-to stack vars
    while to_process:
        bl, seen, written = to_process.pop()
        seen.add(bl)
        cfg_node = self.patcher.cfg.get_any_node(bl.addr)
        if not cfg_node:
            continue
        insts = cfg_node.instruction_addrs
        for i in insts:
            if i in inverted_stack_accesses:
                actions = inverted_stack_accesses[i]
                for arg, action in actions:
                    if arg >= 0:
                        continue
                    if action == "write":
                        written.add(arg)
                    elif action == "read" and arg not in written and arg % 4 == 0:
                        def_uninitialized_reads.add(arg)
                    elif action == "load" and arg not in written and arg % 4 == 0:
                        # get the min > arg
                        subset_written = set(a for a in written if a > arg)
                        if len(subset_written) == 0:
                            the_next = 0
                        else:
                            the_next = min(subset_written)
                        uninitialized_size = the_next - arg
                        target_kind = self.patcher.project.factory.block(i).vex.constant_jump_targets_and_jumpkinds
                        if len(target_kind) == 1 and target_kind.keys()[0] in self.patcher.cfg.functions:
                            call_target = self.patcher.cfg.functions[target_kind.keys()[0]]
                        else:
                            call_target = None
                        # if the target is a syscall wrapper (not transmit) it's safe
                        if call_target is None or \
                                cfg_utils.detect_syscall_wrapper(self.patcher, call_target) == 2 or \
                                not cfg_utils.detect_syscall_wrapper(self.patcher, call_target):
                            if uninitialized_size < 0x40:
                                possible_uninitialized_reads.update(arg + x for x in xrange(0, uninitialized_size, 4))
                                written.update(arg + x for x in xrange(0, uninitialized_size, 4))
                        else:
                            written.add(arg)
        succs = ff.graph.successors(bl)
        for s in succs:
            if s not in seen:
                seen.add(s)
                to_process.append((s, set(seen), set(written)))

    def_uninitialized_reads = sorted(def_uninitialized_reads)
    possible_uninitialized_reads = sorted(possible_uninitialized_reads)
    if len(def_uninitialized_reads) > 0:
        l.debug("definite uninitialized read by func %#x of vars %s",
                ff.addr, map(hex, def_uninitialized_reads))
    if len(possible_uninitialized_reads) > 0:
        l.debug("possible uninitialized read by func %#x of vars %s",
                ff.addr, map(hex, possible_uninitialized_reads))

    to_zero = sorted(def_uninitialized_reads + possible_uninitialized_reads)[::-1]
    if len(to_zero) == 0:
        return
    to_zero = [x - 4 for x in to_zero]

    free_regs = set()
    for r in self.relevant_registers:
        if self.is_reg_free(ff.addr, r, False):
            free_regs.add(r)
    free_regs = list(free_regs)

    # note that all of to_zero should be < 0
    if any(v >= 0 for v in to_zero):
        return

    patch_name = "uninit_patch%#x" % ff.addr
    if len(to_zero) == 1:
        code = "mov DWORD [esp%#x], 0; " % to_zero[0]
        l.debug("adding:\n%s", code)
        patch = InsertCodePatch(ff.addr, code, patch_name, stackable=True)
        self.patches.append(patch)
        return

    groups = self._make_groups(to_zero)
    prefix = ""
    suffix = ""
    body = ""
    # use a reg as 0
    if len(free_regs) == 0:
        prefix += "push eax; \n"
        zero_reg = "eax"
        suffix += "pop eax; "
        groups = self._fix_groups(groups, 4)
    else:
        zero_reg = free_regs[0]
    prefix += "xor %s, %s; \n" % (zero_reg, zero_reg)

    offset_reg = "XXX"  # these should not be used
    offset_reg_curr = 0xffff  # these should not be used
    min_group_size = 3
    if any(len(g) >= min_group_size for g in groups):
        # use a register for the offset
        if len(free_regs) > 1:
            offset_reg = free_regs[1]
        else:
            prefix += "push edi; \n"
            offset_reg = "edi"
            suffix = "pop edi; \n" + suffix
            groups = self._fix_groups(groups, 4)
            if not any(len(g) >= min_group_size for g in groups):
                min_group_size = min(len(g) for g in groups)
        first_group_off = next(g[0] for g in groups if len(g) >= min_group_size)
        prefix += "lea %s, [esp%#x]; \n" % (offset_reg, first_group_off)
        offset_reg_curr = first_group_off

    for g in groups:
        if len(g) < min_group_size:
            for off in g:
                body += "mov DWORD [esp%#x], %s; \n" % (off, zero_reg)
        else:
            if offset_reg_curr != g[0]:
                if g[0] - offset_reg_curr > 0:
                    body += "add %s, %#x; \n" % (offset_reg, g[0] - offset_reg_curr)
                else:
                    body += "sub %s, %#x; \n" % (offset_reg, offset_reg_curr - g[0])
                offset_reg_curr = g[0]
            for off in g:
                if off == offset_reg_curr:
                    body += "mov DWORD [%s], %s; \n" % (offset_reg, zero_reg)
                else:
                    if off - offset_reg_curr >= 0:
                        l.debug("bad error, skipping patch")
                        return
                    body += "mov DWORD [%s%#x], %s; \n" % (offset_reg, off - offset_reg_curr, zero_reg)

    code = prefix + body + suffix
    l.debug("adding:\n%s", code)
    self.patches.append(InsertCodePatch(ff.addr, code, patch_name, stackable=True))
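# _make_groups and _fix_groups are helpers defined elsewhere in this class.
# A minimal sketch of the grouping behavior the loop above relies on: runs of
# offsets 4 bytes apart (highest first, matching the descending to_zero order
# and the "off - offset_reg_curr must be negative" check) become one group, so
# an entire run can be zeroed through a single pointer register. This is our
# illustration under those assumptions, not the repo's actual code; _fix_groups
# presumably re-bases offsets after a push shifts esp by 4.
def _make_groups_sketch(to_zero):
    groups = []
    for off in sorted(to_zero, reverse=True):
        if groups and off == groups[-1][-1] - 4:
            groups[-1].append(off)  # contiguous with the current run
        else:
            groups.append([off])    # start a new run
    return groups

# e.g. _make_groups_sketch([-0x10, -0xc, -0x8, -0x20]) -> [[-0x8, -0xc, -0x10], [-0x20]]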