def instrument_init_array(self):
    """Hook the ASan module constructor/destructor into the rewritten binary.

    Appends ``.quad <sp.ASAN_INIT_FN>`` to ``.init_array`` and
    ``.quad <sp.ASAN_DEINIT_FN>`` to ``.fini_array``. Either array is created
    as an empty DataSection if the binary has none.

    It then adds two instrumented Function objects, at ASAN_INIT_LOC and
    ASAN_DEINIT_LOC. Each one's single instruction is ``sp.MODULE_INIT`` or
    ``sp.MODULE_DEINIT``, joined with newlines and formatted with
    ``global_count=self.global_count``.
    """
    # XXX: yet todo
    # this one is old version for x86, needs to be converted
    container = self.rewriter.container

    # .fini_array was already created on demand but .init_array was not, so a
    # binary without an .init_array raised KeyError here. Handle both the same.
    for secname, symbol in ((".init_array", sp.ASAN_INIT_FN),
                            (".fini_array", sp.ASAN_DEINIT_FN)):
        if secname not in container.sections:
            container.add_section(DataSection(secname, 0, 0, ""))
        entry = DataCell.instrumented(".quad {}".format(symbol), 8)
        container.sections[secname].cache.append(entry)

    # Emit the constructor and destructor bodies as instrumented functions.
    for name, loc, template in (
            (sp.ASAN_INIT_FN, ASAN_INIT_LOC, sp.MODULE_INIT),
            (sp.ASAN_DEINIT_FN, ASAN_DEINIT_LOC, sp.MODULE_DEINIT)):
        func = Function(name, loc, 0, "")
        func.set_instrumented()
        code = '\n'.join(template).format(global_count=self.global_count)
        func.cache.append(InstrumentedInstruction(code, None, None))
        container.add_function(func)
def get_mem_instrumentation(self, instruction, idx, free):
    """Build the coverage trampoline inserted for *instruction*.

    The ``trampoline_fmt_arm`` template is formatted with a random integer
    ``random`` drawn from ``[0, MAP_SIZE)``.

    Args:
        instruction: instruction being instrumented. Its address names the
            ``COUNTER_<hex addr>`` label and its text goes into the comment.
        idx: unused; kept for signature compatibility.
        free: free registers at this point; only recorded in the comment.

    Returns:
        InstrumentedInstruction(trampoline code, label, comment).
    """
    enter_lbl = "COUNTER_%x" % (instruction.address)
    # random.randint() includes its upper bound, so randint(0, MAP_SIZE)
    # could return MAP_SIZE itself. The value presumably selects a slot in a
    # MAP_SIZE-entry coverage map (AFL convention; confirm against
    # trampoline_fmt_arm), whose valid range is [0, MAP_SIZE).
    # randrange() excludes the upper bound.
    instrumentation = trampoline_fmt_arm.format(
        random=random.randrange(MAP_SIZE))
    comment = "{}: {}".format(str(instruction), str(free))
    return InstrumentedInstruction(instrumentation, enter_lbl, comment)
def instrument_persistent(self, addr):
    """Make ``_start`` hand ``__afl_persistent_instrumentation`` to libc instead of main.

    Every instruction in ``_start`` whose text contains ``"ldr x0, [x0"`` is
    commented out, together with the instruction just before it. In its place,
    an ``adrp``/``add :lo12:`` pair loading the address of
    ``__afl_persistent_instrumentation`` into x0 is emitted after it. The code
    assumes that this ldr and its predecessor form the sequence that loads
    main's address into x0 (the original comment says "passing main as first
    argument") — confirm for non-glibc start files.

    Args:
        addr: unused.

    Raises:
        ValueError: if the binary has no function named ``_start``.
    """
    # we need to call __afl_persistent_instrumentation instead of main
    # so we hijack the _start
    start_func = next(
        (fn for fn in self.rewriter.container.functions.values()
         if fn.name == "_start"),
        None)
    if start_func is None:
        raise ValueError("instrument_persistent: no _start function found")

    hijack = ("adrp x0, __afl_persistent_instrumentation\n"
              "add x0, x0, :lo12:__afl_persistent_instrumentation")
    for idx, instruction in enumerate(start_func.cache):
        # passing main as first argument
        if "ldr x0, [x0" not in str(instruction):
            continue
        instruction.mnemonic = "// " + instruction.mnemonic
        # The previous instruction must come from _start itself. The old code
        # indexed the leftover loop variable `fn`, i.e. whichever function was
        # iterated last. idx == 0 has no predecessor; -1 would wrap to the
        # end of the function.
        if idx > 0:
            prev_instruction = start_func.cache[idx - 1]
            prev_instruction.mnemonic = "// " + prev_instruction.mnemonic
        instruction.instrument_after(InstrumentedInstruction(hijack))
def count_one(self, instruction, idx, free):
    """Return code that increments the 64-bit ``.counted`` cell by one.

    The snippet spills x7/x8 to the stack and builds the address of
    ``.counted`` in x8 (adrp + ``:lo12:`` add). It then performs
    load / add 1 / store and reloads x7/x8. None of the emitted instructions
    set flags, so NZCV is preserved.

    Args:
        instruction: the instruction being counted. Its address names the
            ``COUNTER_<hex addr>`` label and its text becomes the comment.
        idx: unused.
        free: unused.

    Returns:
        InstrumentedInstruction(snippet, label, comment).
    """
    label = "COUNTER_{:x}".format(instruction.address)
    note = "{}: ".format(str(instruction))
    snippet = """
    stp x7, x8, [sp, -16]! // save x7, x8
    // build a pointer in x8 to .counted
    adrp x8, .counted
    add x8, x8, :lo12:.counted
    // add 1 to .counted
    ldr x7, [x8]
    add x7, x7, 1
    str x7, [x8]
    ldp x7, x8, [sp], 16 // load back x7 and x8
    """
    return InstrumentedInstruction(snippet, label, note)
def unpoison_stack(self, args, need_save):
    """Build the code that unpoisons (clears the ASan shadow of) a stack slot.

    Args:
        args: format arguments for the ``sp`` templates. NOTE: modified in
            place; ``args["off"]`` is set to ASAN_SHADOW_OFF.
        need_save: when true, the first line of ``sp.MEM_REG_SAVE`` and of
            ``sp.MEM_REG_RESTORE`` wrap the code to preserve the register it
            clobbers.

    Returns:
        InstrumentedInstruction whose label is ``sp.STACK_EXIT_LBL``
        formatted with *args*, and which has no comment.
    """
    body = []
    if need_save:
        body.append(sp.MEM_REG_SAVE[0])
    # Shared setup from STACK_POISON_BASE, then the per-slot unpoison template.
    body += sp.STACK_POISON_BASE
    args["off"] = ASAN_SHADOW_OFF
    body.append(sp.STACK_UNPOISON_SLOT)
    if need_save:
        body.append(sp.MEM_REG_RESTORE[0])
    asm = "\n".join(body).format(**args)
    exit_label = sp.STACK_EXIT_LBL.format(**args)
    return InstrumentedInstruction(asm, exit_label, None)
def do_instrument(self):
    """Count executed indirect branches and dump the totals at exit.

    * Inserts ``self.count_one`` (which increments ``.counted``) before every
      ``br``-family instruction. Inserts ``self.count_two`` before every
      ``blr``-family instruction; that method is defined elsewhere and
      presumably increments ``.counted2`` (confirm).
    * Adds a writable ``.counter`` section holding the output path, the fopen
      mode, the fprintf format and the two 64-bit counters.
    * Adds a ``.finiamola`` section containing code that opens /tmp/countfile
      with fopen, fprintf's both counters into it and then calls ``exit``.
      A pointer to it is appended to ``.fini_array``, and the same code is
      also inserted before the first instruction of the first function of
      ``.fini``.

    NOTE(review): the dump code is therefore reachable twice. It never
    returns; it ends in ``bl exit`` with x0 still holding fprintf's return
    value. Calling exit() while exit handlers run is undefined per the C
    standard. Confirm this is intended.
    """
    for faddr, fn in self.rewriter.container.functions.items():
        for idx, instruction in enumerate(fn.cache):
            mnemonic = instruction.mnemonic
            # Prefix match covers br/braa/brab/... and blr/blraa/...
            # The old substring test ("br" in mnemonic) also counted brk
            # (and SVE brka/brkb/brkn/brkp*) as indirect branches.
            if mnemonic.startswith("br") and not mnemonic.startswith("brk"):
                iinstr = self.count_one(instruction, idx, None)
                instruction.instrument_before(iinstr)
            elif mnemonic.startswith("blr"):
                iinstr = self.count_two(instruction, idx, None)
                instruction.instrument_before(iinstr)

    ds = Section(".counter", 0x100000, 0, None, flags="aw")
    content = """
.file: .string \"/tmp/countfile\"
.perms: .string \"a\"
.format: .string \"br: %lld\\nblr: %lld\\n\"
.align 3
.counted: .quad 0x0
.counted2: .quad 0x0
"""
    ds.cache.append(DataCell.instrumented(content, 0))
    self.rewriter.container.add_data_section(ds)

    ds = Section(".finiamola", 0x200000, 0, None, flags="ax")
    ds.align = 0
    instrumentation = """
    // build a pointer to .perms
    adrp x1, .perms
    add x1, x1, :lo12:.perms
    // build a pointer to .file
    adrp x0, .file
    add x0, x0, :lo12:.file
    // call the libc fopen(.file, .perms)
    bl fopen
    // load .counted in x2
    adrp x2, .counted
    ldr x2, [x2, :lo12:.counted]
    // load .counted in x3
    adrp x3, .counted2
    ldr x3, [x3, :lo12:.counted2]
    // build a pointer to .format
    adrp x1, .format
    add x1, x1, :lo12:.format
    // fprintf( fopen("/tmp/countfile", "a+"), "%lld", counted);
    bl fprintf
    bl exit
    """
    ds.cache.append(DataCell.instrumented(instrumentation, 0))
    self.rewriter.container.add_data_section(ds)
    self.rewriter.container.datasections[".fini_array"].cache.append(
        DataCell.instrumented(".quad .finiamola", 0))

    # NOTE(review): 0 is passed where other call sites pass a label or None.
    f = self.rewriter.container.codesections[".fini"].functions[0]
    self.rewriter.container.functions[f].cache[0].instrument_before(
        InstrumentedInstruction(instrumentation, 0))
def get_mem_instrumentation(self, acsz, instruction, midx, free, is_leaf, bool_load, rbase_reg):
    """Build the ASan shadow-memory check for one AArch64 load/store.

    Placeholder-comment instructions are returned instead of a check for:
      * accesses that read or write sp, x29 or x28;
      * instructions that already carry before/after instrumentation;
      * ``ldr <reg>, =...`` literal loads.

    Otherwise the emitted code is, in order:
      1. register spills (when fewer than five scratch registers are free)
         and an NZCV save;
      2. the effective-address computation into the first scratch register
         (skipped for plain ``[xN]`` accesses, which use the base register
         directly);
      3. ``sp.ASAN_BASE`` (unless ``rbase_reg`` is given);
      4. the size-specific ``sp.MEM_LOAD_<acsz>`` check and
         ``sp.ASAN_REPORT``;
      5. ``sp.MEM_EXIT_LABEL``;
      6. the restores, in reverse order.

    The joined text is formatted with lexp/acsz/r1/r1_32/r2/r2_32/rbase/addr/
    acctype.

    Args:
        acsz: access size in bytes; must be 1, 2, 4, 8 or 16.
        instruction: the memory-accessing instruction (capstone object in
            ``instruction.cs``).
        midx: unused here. The memory-operand index is re-derived from
            ``instruction.get_mem_access_op()``.
        free: registers reported free at this instruction.
            NOTE(review): mutated in place; "rflags" is removed if present
            (x86 leftover).
        is_leaf: when true, no register is treated as free (all scratch
            registers get spilled). LEAF_STACK_ADJUST / LEAF_STACK_UNADJUST
            become the first and last lines.
        bool_load: selects ``acctype`` "load" (true) or "store" (false).
        rbase_reg: register already holding the ASan base (batching), or a
            falsy value to emit ``sp.ASAN_BASE`` inline.

    Returns:
        Either an InstrumentedInstruction labelled
        ``<sp.ASAN_MEM_ENTER>_<hex addr>`` with comment
        ``"<instruction>: <free>"``, or an unlabelled placeholder for the
        skipped cases above.
    """
    if "sp" in instruction.reg_reads() or "sp" in instruction.reg_writes():
        debug("we do not instrument push/pop for now")
        return InstrumentedInstruction("# not instrumented - push/pop")
    if "x29" in instruction.reg_reads() or "x29" in instruction.reg_writes(
    ):
        debug("we do not instrument stack frames for now")
        return InstrumentedInstruction(
            "# not instrumented - stackframe push/pop")
    if "x28" in instruction.reg_reads() or "x28" in instruction.reg_writes(
    ):
        debug("we do not instrument stack frames for now")
        return InstrumentedInstruction(
            "# not instrumented - stackframe push/pop")
    #XXX: this should skip only other ASAN instrumentation, not any instrumentation in general
    if len(instruction.before) > 0 or len(instruction.after) > 0 or\
        (instruction.mnemonic == "ldr" and "=" in instruction.op_str):
        return InstrumentedInstruction(
            "# Already instrumented - skipping bASAN")

    # we prefer high registers, less likely to go wrong
    # (candidate scratch registers x17 down to x0)
    affinity = ["x" + str(i) for i in range(17, -1, -1)]
    # do not use registers used by the very same instruction!
    # (only registers the instruction READS are excluded; registers it only
    # writes may be used, since the check and its restores run before it)
    for reg in instruction.reg_reads():
        reg64 = get_64bits_reg(reg) if is_reg_32bits(reg) else reg
        if reg64 in affinity:
            affinity.remove(reg64)
    # Free registers ordered by affinity; registers outside affinity go last.
    free_regs = sorted(list(free), key=lambda x: affinity.index(x) if x in affinity else len(affinity))
    # *very rarely* a leaf function does not respect the ABI,
    # and we cannot assume we have free registers. As an example, look at the
    # call to et_splay in et_set_father in the gcc_r speccpu benchmark
    # This could be optimized by only restricting to registers present in the leaf function
    if is_leaf:
        free_regs = []
    # do not use registers used for batching
    # NOTE(review): rbase_reg is only removed from free_regs; it stays in
    # `affinity` and could still be chosen for spilling below.
    if rbase_reg and rbase_reg in free_regs:
        # NOTE(review): leftover debug output to stdout.
        print(affinity)
        print(rbase_reg)
        free_regs.remove(rbase_reg)
    codecache = list()
    save = list()
    restore = list()
    fix_lexp = list()
    save_rflags = "unopt"
    save_rax = True
    push_cnt = 0
    # XXX: for access sizes 8 and 16:
    # we need one register less
    num_regs_used = 5
    # Five scratch registers: [0] address (lexp), [1] r1, [2] r2,
    # [3] rbase (when not batched), [4] NZCV copy.
    asan_regs = free_regs[:num_regs_used]
    if len(asan_regs) < num_regs_used:
        # if there aren't enough we save them on the stack
        non_free = [reg for reg in affinity if reg not in asan_regs]
        to_save_regs = non_free[:num_regs_used - len(asan_regs)]
        i = 0
        # first save them in pairs (faster); restores are prepended so they
        # run in reverse order of the saves
        for i in range(0, len(to_save_regs) - 1, 2):
            save.append(
                copy.copy(sp.STACK_PAIR_REG_SAVE)[0].format(
                    *to_save_regs[i:i + 2]))
            restore.insert(
                0,
                copy.copy(sp.STACK_PAIR_REG_LOAD)[0].format(
                    *to_save_regs[i:i + 2]))
        # if there is a single one left
        if len(to_save_regs) % 2 == 1:
            save.append(
                copy.copy(sp.STACK_REG_SAVE)[0].format(to_save_regs[-1]))
            restore.insert(
                0, copy.copy(sp.STACK_REG_LOAD)[0].format(to_save_regs[-1]))
        asan_regs += to_save_regs
        push_cnt += len(to_save_regs)
    save_condition_reg = True
    # if acsz < 8:
    # XXX: should check whether we actually need this or not
    # (NZCV is copied into asan_regs[4] and written back during restore)
    if save_condition_reg:
        save.append("\tmrs {0}, nzcv".format(asan_regs[4]))
        restore.insert(0, "\tmsr nzcv, {0}".format(asan_regs[4]))
    mem, mem_op_idx = instruction.get_mem_access_op()
    mem_op = instruction.cs.operands[mem_op_idx]
    cs = instruction.cs
    lexp = asan_regs[0]  # the first free register
    # ldr x0, [x1, x2, LSL#3]
    if mem_op.shift.value != 0:
        amnt = mem_op.shift.value
        to, sxtw = asan_regs[0], ""
        shift_reg = cs.reg_name(mem.index)
        # NOTE(review): a 32-bit index is always sign-extended here; an
        # access using uxtw would be mis-modelled. Confirm.
        if is_reg_32bits(shift_reg):
            to, sxtw = self._get_subreg32(asan_regs[0]), ", sxtw"
        fix_lexp += [
            sp.LEXP_SHIFT.format(To=to, Res=asan_regs[0], From=cs.reg_name(mem.base), amnt=amnt, shift_reg=shift_reg, sxtw=sxtw)
        ]
    # ldr x0, [x1, x2]
    elif mem.index != 0:
        reg_index = cs.reg_name(mem.index)
        if is_reg_32bits(reg_index):
            reg_index += ", sxtw"  # quick hack (same sxtw assumption as above)
        fix_lexp += [
            sp.LEXP_ADD.format(To=asan_regs[0], From=cs.reg_name(mem.base), amnt=reg_index)
        ]
    # ldr x0, [x1, #12]
    # NOTE(review): handling of negative displacements and of values beyond
    # the movz immediate range depends on the LEXP_ADD / LEXP_MOVZ templates.
    # Confirm.
    elif mem.disp != 0:
        if mem.disp > (1 << 12):  # aarch64 limitation
            fix_lexp += [
                sp.LEXP_MOVZ.format(To=asan_regs[0], amnt=mem.disp)
            ]
            fix_lexp += [
                sp.LEXP_ADD.format(To=asan_regs[0], From=cs.reg_name(mem.base), amnt=asan_regs[0])
            ]
        else:
            fix_lexp += [
                sp.LEXP_ADD.format(To=asan_regs[0], From=cs.reg_name(mem.base), amnt=mem.disp)
            ]
    # ldr x0, [x1]  -- the base register itself is the address to check
    if mem.disp == 0 and mem.index == 0 and mem_op.shift.value == 0:
        lexp = cs.reg_name(mem.base)
    # x86 leftover: save_rflags / save_rax are not read after this point.
    if "rflags" in free:
        save_rflags = False
        free.remove("rflags")
    if "rax" in free:
        save_rax = False
    # XXX:
    # if len(free) > 0:
    #     r2 = [False, "%{}".format(free[0])]
    # if len(free) > 1:
    #     r1 = [False, "%{}".format(free[1])]
    # if r2[1] == r1[1]:
    #     r1 = [True, "x1"]
    # if save_rflags:
    #     save_rflags = "opt"
    #     save_rax = "rax" not in free
    #XXX: should remove this, we dont use red zones on the stack yet
    if is_leaf:
        save.insert(0, sp.LEAF_STACK_ADJUST)
        restore.append(sp.LEAF_STACK_UNADJUST)
    # this has to do with red zones (kernel x64 does not have them) (are you sure?)
    # https://github.com/torvalds/linux/blob/9f159ae07f07fc540290f219372
    # if is_leaf and (any([r[0] for r in asan_regs]) or save_rflags):
    #     save.append(sp.LEAF_STACK_ADJUST)
    #     restore.append(sp.LEAF_STACK_UNADJUST)
    #     push_cnt += 32
    # x86 leftover: lexp always holds an AArch64 register name here, so this
    # presumably never fires. The emitted leaq is x86 syntax.
    if push_cnt > 0 and '%rsp' in lexp:
        # In this case we have a stack-relative load but the value of the stack
        # pointer has changed because we pushed some registers to the stack
        # to save them. Adjust the displacement of the access to take this
        # into account
        # XXX ARM:
        save.append("leaq {}(%rsp), %rsp".format(push_cnt * 8))
        restore.insert(0, "leaq -{}(%rsp), %rsp".format(push_cnt * 8))
    memcheck = ""
    if not rbase_reg:  # we could not batch the ASAN base
        memcheck += copy.copy(sp.ASAN_BASE)
    if acsz == 1:
        memcheck += copy.copy(sp.MEM_LOAD_1)
    elif acsz == 2:
        memcheck += copy.copy(sp.MEM_LOAD_2)
    elif acsz == 4:
        memcheck += copy.copy(sp.MEM_LOAD_4)
    elif acsz == 8:
        memcheck += copy.copy(sp.MEM_LOAD_8)
    elif acsz == 16:
        memcheck += copy.copy(sp.MEM_LOAD_16)
    else:
        # NOTE(review): assert is stripped under python -O.
        assert False, "Reached unreachable code!"
    memcheck += copy.copy(sp.ASAN_REPORT)
    # Assemble: saves, address computation, check, exit label, restores.
    codecache.extend(save)
    if len(fix_lexp):
        codecache.append('\n'.join(fix_lexp))
    codecache.append(memcheck)
    codecache.append(sp.MEM_EXIT_LABEL)
    # codecache.append(
    #     copy.copy(sp.MEM_EXIT_LABEL)[0].format(addr=instruction.address))
    codecache.extend(restore)
    # Placeholders substituted into the templates above.
    args = dict()
    args["lexp"] = lexp
    args["acsz"] = acsz
    args["r1"] = asan_regs[1]
    args["r1_32"] = self._get_subreg32(asan_regs[1])
    args["r2"] = asan_regs[2]
    args["r2_32"] = self._get_subreg32(asan_regs[2])
    args["rbase"] = rbase_reg if rbase_reg else asan_regs[3]
    args["addr"] = instruction.address
    enter_lbl = "%s_%x" % (sp.ASAN_MEM_ENTER, instruction.address)
    args['acctype'] = 'load' if bool_load else 'store'
    codecache = '\n'.join(codecache)
    comment = "{}: {}".format(str(instruction), str(free))
    return InstrumentedInstruction(codecache.format(**args), enter_lbl, comment)
def handle_longjmp(self, instruction):
    """Insert the ``sp.LONGJMP_UNPOISON`` code right before *instruction*.

    The template lines are joined with newlines and formatted with
    ``reg``, ``addr`` (the instruction address) and ``off``
    (ASAN_SHADOW_OFF). The result is attached as unlabelled, uncommented
    before-instrumentation.

    NOTE(review): ``reg`` is "%r9", which is x86 AT&T register syntax. If
    LONGJMP_UNPOISON is an AArch64 template that uses {reg}, this needs an
    AArch64 scratch register (e.g. x9). Confirm against the template.
    """
    fmt_args = {
        "reg": "%r9",
        "addr": instruction.address,
        "off": ASAN_SHADOW_OFF,
    }
    template = "\n".join(sp.LONGJMP_UNPOISON)
    instruction.instrument_before(
        InstrumentedInstruction(template.format(**fmt_args), None, None))
def instrument_mem_accesses(self):
    """Attach an out-of-line ASan check to every eligible memory access.

    Each function is walked basic block by basic block; a block runs from
    one entry of the sorted ``fn.bbstarts`` up to the instruction before the
    next entry. Within each block, the code selects loads/stores that are
    not SIMD, do not involve sp/x29/x28, and have a supported access size.
    For each selected access it:

    * inserts ``b .I<hexaddr>`` before the original instruction, and
    * appends to the returned text an ``.I<hexaddr>:`` label, the check
      built by ``self.get_mem_instrumentation`` and ``b .LC<hexaddr> + 4``.

    Selected instruction indices are also recorded in
    ``self.mem_instrumentation_stats[fn.start]``.

    Returns:
        str: the newline-joined out-of-line check stubs.
    """
    jumps_instrumentation = list()
    for _, fn in self.rewriter.container.functions.items():
        # if any([s in fn.name for s in ["alloc", "signal_is_trapped", "free", "gimplify"]]):
        #     info(f"Skipping instrumentation on function {fn.name} to avoid custom heap implementations")
        #     continue
        if not fn.cache:
            continue
        is_leaf = fn.analysis.get(StackFrameAnalysis.KEY_IS_LEAF, False)
        # Sort once per function (this used to be re-sorted for every block).
        bbstarts = sorted(fn.bbstarts)

        # First, we analyze basic blocks, to check if we can batch instrumentation
        for e, bb_start in enumerate(bbstarts):
            # A block ends one instruction before the next block starts; the
            # last block runs to the end of the function.
            if e + 1 < len(bbstarts):
                bb_end = bbstarts[e + 1] - INSTR_SIZE
            else:
                bb_end = fn.cache[-1].address
            if bb_start not in fn.addr_to_idx:
                critical(
                    f"basic block error: {hex(bb_start)} not in function {fn.name}, starting at {hex(fn.cache[0].address)}, ending at {hex(fn.cache[-1].address)}"
                )
                continue
            if bb_end not in fn.addr_to_idx:
                critical(
                    f"basic block error: {hex(bb_end)} not in function {fn.name}, starting at {hex(fn.cache[0].address)}, ending at {hex(fn.cache[-1].address)}"
                )
                continue
            block_addrs = range(bb_start, bb_end + INSTR_SIZE, INSTR_SIZE)

            # Registers free across the whole block. They are only consumed by
            # the (disabled) ASan-base batching below and by debug output.
            regs = set(non_clobbered_registers)
            regs.remove("x0")  # cannot overwrite return value
            for iaddr in block_addrs:
                idx = fn.addr_to_idx[iaddr]
                free_registers = fn.analysis['free_registers'][idx]
                if fn.cache[idx].mnemonic.startswith("bl"):  # XXX: WTFFFFFF
                    regs = set()
                regs = regs.intersection(free_registers)

            # compute which instructions we are going to instrument (to_instrument)
            to_instrument = []
            for iaddr in block_addrs:
                idx = fn.addr_to_idx[iaddr]
                instruction = fn.cache[idx]
                if isinstance(instruction, InstrumentedInstruction) or \
                        instruction.address in self.skip_instrument:
                    continue
                mem, midx = instruction.get_mem_access_op()
                if not mem:  # This is not a memory access
                    continue
                # XXX: not supporting SIMD instructions for now (ld1, st3 ...)
                # reg_name() returns None when operand 0 is not a register;
                # the old `[0]` subscript crashed on that.
                first_reg = instruction.cs.reg_name(
                    instruction.cs.operands[0].reg)
                if first_reg and first_reg[0] == 'v':
                    debug(
                        f"Skipping BAsan instrumentation on SIMD instr {instruction.cs}"
                    )
                    continue
                # look at start of get_mem_instrumentation()
                touched = instruction.reg_reads() + instruction.reg_writes()
                if any(x in touched for x in ["sp", "x29", "x28"]):
                    continue
                free_registers = fn.analysis['free_registers'][idx]
                acsz, bool_load = get_access_size_arm(instruction.cs)
                if acsz not in [1, 2, 4, 8, 16]:
                    critical(
                        f"[*] Missed an access: {instruction} -- {acsz}")
                    continue
                self.mem_instrumentation_stats[fn.start].append(idx)
                to_instrument += [(acsz, instruction, midx, free_registers,
                                   is_leaf, bool_load)]

            # Batching of the ASan base register is disabled:
            # if len(regs) and len(to_instrument) > 1:
            #     rbase_reg = sorted(regs)[-1]
            #     fn.cache[fn.addr_to_idx[bb_start]].instrument_before(InstrumentedInstruction(f"mov {rbase_reg}, 0x1000000000"))
            # else:
            #     rbase_reg = None
            rbase_reg = None

            # now we actually instrument the selected instructions
            for acsz, instruction, midx, free_registers, leaf, bool_load in to_instrument:
                debug(f"{instruction} --- acsz: {acsz}, load: {bool_load}")
                iinstr = self.get_mem_instrumentation(
                    acsz, instruction, midx, free_registers, leaf,
                    bool_load, rbase_reg)
                hexaddr = f"{instruction.address:x}"
                # str.join() only accepts str items and iinstr is an
                # InstrumentedInstruction, so convert it explicitly (joining
                # the raw object raised TypeError).
                jumps_instrumentation += [
                    f".I{hexaddr}:",
                    str(iinstr),
                    f"b .LC{hexaddr} + 4",
                ]
                instruction.instrument_before(
                    InstrumentedInstruction(f"b .I{hexaddr}"))
            debug(f"{hex(bb_start)}, {hex(bb_end)}")
            debug(f"{regs}")
    return "\n".join(jumps_instrumentation)