Exemplo n.º 1
0
    def instrument_init_array(self):
        """Hook the ASan module constructor and destructor into the binary.

        For each of ``.init_array`` and ``.fini_array`` an instrumented
        ``.quad`` cell naming ``sp.ASAN_INIT_FN`` / ``sp.ASAN_DEINIT_FN`` is
        appended to the section's cache. The two functions are then added to
        the container; their bodies are the ``sp.MODULE_INIT`` /
        ``sp.MODULE_DEINIT`` line templates joined with newlines and formatted
        with ``self.global_count``.

        A section that the container does not have yet is created empty
        first, so an input lacking ``.init_array`` is handled the same way
        as one lacking ``.fini_array`` instead of raising ``KeyError``.
        """
        container = self.rewriter.container

        # (array section, hook symbol, hook location, body template lines)
        hooks = (
            (".init_array", sp.ASAN_INIT_FN, ASAN_INIT_LOC, sp.MODULE_INIT),
            (".fini_array", sp.ASAN_DEINIT_FN, ASAN_DEINIT_LOC,
             sp.MODULE_DEINIT),
        )

        # Register the hook pointers in the constructor/destructor arrays.
        for secname, fn_name, _loc, _template in hooks:
            if secname not in container.sections:
                container.add_section(DataSection(secname, 0, 0, ""))
            pointer = DataCell.instrumented(".quad {}".format(fn_name), 8)
            container.sections[secname].cache.append(pointer)

        # Emit the hook functions the array entries point at.
        for _secname, fn_name, fn_loc, template in hooks:
            hook_fn = Function(fn_name, fn_loc, 0, "")
            hook_fn.set_instrumented()
            body = '\n'.join(template).format(global_count=self.global_count)
            hook_fn.cache.append(InstrumentedInstruction(body, None, None))
            container.add_function(hook_fn)
Exemplo n.º 2
0
    def do_instrument(self):
        """Insert a ``__sanitizer_cov_trace_pc`` call at every basic block.

        Every function in the container is marked as instrumented. For each
        instruction whose address is listed in the function's ``bbstarts``,
        a snippet is attached that pushes the flags and every register in
        ``CALLER_SAVED_REGS``, calls the tracer, and pops everything back in
        reverse order. No liveness information is consulted: flags and all
        listed registers are always saved.
        """
        for fn in self.rewriter.container.iter_functions():
            fn.set_instrumented()

            for instr in fn.cache:
                if instr.address not in fn.bbstarts:
                    continue

                saved = CALLER_SAVED_REGS

                snippet = ['\tpushfq']
                snippet += ['\tpushq %{}'.format(reg) for reg in saved]
                snippet.append('\tcallq __sanitizer_cov_trace_pc')
                snippet += ['\tpopq %{}'.format(reg)
                            for reg in reversed(saved)]
                snippet.append('\tpopfq')

                # At offset 0 the snippet has to follow the instruction so it
                # lands after the stack pointer adjustment; everywhere else
                # it precedes the instruction.
                if instr.address.offset == 0:
                    attach = instr.instrument_after
                else:
                    attach = instr.instrument_before

                attach(InstrumentedInstruction('\n'.join(snippet)))
Exemplo n.º 3
0
    def poison_stack(self, args, need_save, instruction, midx):
        """Build the instrumentation that poisons a stack slot.

        Args:
            args: Format arguments for the assembly templates. This dict is
                modified in place: ``args["off"]`` is set to
                ``ASAN_SHADOW_OFF`` and, when registers are saved and
                ``args["pbase"]`` is ``%rsp``-relative, ``args["pbase"]`` is
                rewritten with the adjusted displacement.
            need_save: Number of scratch registers to save around the
                poisoning code. ``> 0`` saves ``{reg2}``; ``> 1`` also saves
                ``{reg1}``.
            instruction: Instruction being instrumented; its capstone operand
                ``cs.operands[midx]`` supplies the memory displacement.
            midx: Index of the memory operand in ``instruction.cs.operands``.

        Returns:
            An ``InstrumentedInstruction`` holding the formatted code, labelled
            with ``sp.STACK_ENTER_LBL`` formatted from ``args``.

        Raises:
            AssertionError: If ``args["pbase"]`` is ``%rsp``-relative but the
                operand's displacement cannot be found in it, in either hex
                or decimal form.
        """
        instrumentation = []

        # Save the registers we're about to clobber
        if need_save > 1:
            instrumentation.append(
                sp.MEM_REG_SAVE[0].replace('{reg}', '{reg1}'))

        if need_save > 0:
            disp = instruction.cs.operands[midx].value.mem.disp
            # Each saved register occupies one 8-byte stack slot, so an
            # %rsp-relative base has to be shifted by that many bytes.
            adjusted_disp = disp + need_save * 8

            instrumentation.append(
                sp.MEM_REG_SAVE[0].replace('{reg}', '{reg2}'))

            pbase = args['pbase']
            if '%rsp' in pbase:
                # The displacement may be spelled in hex or in decimal.
                for spelling in (hex(disp), str(disp)):
                    if spelling in pbase:
                        args['pbase'] = pbase.replace(
                            spelling, hex(adjusted_disp))
                        break
                else:
                    # Raised explicitly (not via ``assert``) so the check
                    # survives ``python -O``; no debugger is entered so
                    # unattended runs fail fast instead of hanging.
                    raise AssertionError('Can\'t find displacement in pbase')

        # Add instrumentation to poison
        instrumentation.extend(sp.STACK_POISON_BASE)

        args["off"] = ASAN_SHADOW_OFF
        instrumentation.append(sp.STACK_POISON_SLOT)

        # Restore clobbered registers in the reverse order of the saves
        if need_save > 0:
            instrumentation.append(
                sp.MEM_REG_RESTORE[0].replace('{reg}', '{reg2}'))

        if need_save > 1:
            instrumentation.append(
                sp.MEM_REG_RESTORE[0].replace('{reg}', '{reg1}'))

        code_str = "\n".join(instrumentation).format(**args)
        return InstrumentedInstruction(code_str,
                                       sp.STACK_ENTER_LBL.format(**args), None)
Exemplo n.º 4
0
    def do_instrument(self):
        """Insert a ``__sanitizer_cov_trace_pc`` call before each basic block.

        Every function in the container is marked as instrumented. For each
        instruction whose address is in the function's ``bbstarts``, the
        ``free_registers`` analysis entry at the same index decides what has
        to be preserved: flags are pushed unless ``'rflags'`` is free, and
        only the ``KcovInstrument.CALLER_SAVED_REGS`` that are not free are
        pushed. When the number of pushed 8-byte slots is odd, an extra
        8 bytes are reserved around the call to keep the stack pointer
        aligned.
        """
        for fn in self.rewriter.container.iter_functions():
            fn.set_instrumented()

            for idx, instr in enumerate(fn.cache):
                if instr.address not in fn.bbstarts:
                    continue

                available = fn.analysis['free_registers'][idx]
                save_flags = 'rflags' not in available
                live_regs = [
                    reg for reg in KcovInstrument.CALLER_SAVED_REGS
                    if reg not in available
                ]

                # One slot per pushed register plus one for the flags; an odd
                # total needs a padding slot to keep the stack aligned.
                slot_count = len(live_regs) + (1 if save_flags else 0)
                needs_pad = (slot_count % 2) != 0

                asm = []
                if save_flags:
                    asm.append('\tpushfq')
                asm.extend('\tpushq %{}'.format(reg) for reg in live_regs)
                if needs_pad:
                    asm.append('\tsubq $8, %rsp')

                asm.append('\tcallq __sanitizer_cov_trace_pc')

                # Undo everything in the opposite order.
                if needs_pad:
                    asm.append('\taddq $8, %rsp')
                asm.extend('\tpopq %{}'.format(reg)
                           for reg in reversed(live_regs))
                if save_flags:
                    asm.append('\tpopfq')

                instr.instrument_before(
                    InstrumentedInstruction('\n'.join(asm)))
Exemplo n.º 5
0
 def handle_longjmp(self, instruction):
     """Attach the longjmp unpoisoning snippet before ``instruction``.

     The ``sp.LONGJMP_UNPOISON`` template lines are joined with newlines and
     formatted with ``%r9`` as the register, the instruction's address and
     ``ASAN_SHADOW_OFF`` as the offset.
     """
     fmt_args = {
         "reg": "%r9",
         "addr": instruction.address,
         "off": ASAN_SHADOW_OFF,
     }
     template = "\n".join(sp.LONGJMP_UNPOISON)
     instruction.instrument_before(
         InstrumentedInstruction(template.format(**fmt_args), None, None))
Exemplo n.º 6
0
    def get_mem_instrumentation(self, acsz, instruction, midx, free, is_leaf):
        """Build the memory-access check emitted for one instruction.

        The generated code is: register/flag saves, the size-specific check
        from ``self._access1/2/4/8``, an exit label, and the matching
        restores. All of it is formatted with the operand text, access size
        and the two scratch registers chosen here.

        Args:
            acsz: Access size; must be 1, 2, 4 or 8.
            instruction: Instruction being checked. Its ``mnemonic``,
                ``op_str``, ``cs.operands`` and ``address`` are read.
            midx: Index of the memory operand. It is forced to 1 for the
                shift/``stos`` mnemonics listed below. ``0`` is reported as a
                ``'load'`` access, anything else as a ``'store'``.
            free: Iterable of names that need not be preserved at this point
                (register names without ``%``, optionally ``"rflags"``).
            is_leaf: Not read by the current code; its only use is in the
                commented-out red-zone handling below.

        Returns:
            An ``InstrumentedInstruction`` built from the formatted code, an
            enter label derived from ``sp.ASAN_MEM_ENTER`` and the
            instruction address, and a comment string containing the
            instruction and the sorted free list.
        """
        # Preference order for picking scratch registers out of ``free``.
        affinity = [
            "rdi", "rsi", "rcx", "rdx", "rbx", "r8", "r9", "r10", "r11", "r12",
            "r13", "r14", "r15", "rax", "rbp"
        ]

        # Sort the free names by that preference; names that are not in
        # ``affinity`` (such as "rflags") sort last.
        free = sorted(list(free),
                      key=lambda x: affinity.index(x)
                      if x in affinity else len(affinity))

        codecache = list()
        save = list()
        restore = list()
        # Flag-saving strategy: "unopt", "opt", or False when flags are free.
        save_rflags = "unopt"
        save_rax = True
        # Scratch registers as [must_be_saved, name]; the defaults are used
        # when no free register is available.
        r1 = [True, "%rdi"]
        r2 = [True, "%rsi"]
        # Number of stack slots the save sequence is counted as using.
        push_cnt = 0

        # XXX: Bug in capstone?
        # For these mnemonics the memory operand index is overridden to 1.
        if any([
                instruction.mnemonic.startswith(x)
                for x in ["sar", "shl", "shl", "stos", "shr", "rep stos"]
        ]):
            midx = 1

        # Extract the text of the memory operand (``lexp``) from ``op_str``
        # according to the operand count and the memory operand's position.
        is_rep_stos = False
        if instruction.mnemonic.startswith("rep stos"):
            is_rep_stos = True
            lexp = instruction.op_str.split(",", 1)[1]
        elif len(instruction.cs.operands) == 1:
            lexp = instruction.op_str
        elif len(instruction.cs.operands) > 2:
            print("[*] Found op len > 2: %s" % (instruction))
            # Drop the first and the last comma-separated operand.
            op1 = instruction.op_str.split(",", 1)[1]
            lexp = op1.rsplit(",", 1)[0]
        elif midx == 0:
            lexp = instruction.op_str.rsplit(",", 1)[0]
        else:
            lexp = instruction.op_str.split(",", 1)[1]

        # Strip a leading '*' (presumably the AT&T indirect-branch marker).
        if lexp.startswith("*"):
            lexp = lexp[1:]

        # Free flags need no saving; remove the pseudo-register so that it
        # is never picked as a scratch register.
        if "rflags" in free:
            save_rflags = False
            free.remove("rflags")

        if "rax" in free:
            save_rax = False

        if len(free) > 0:
            # The most preferred free register becomes the target register,
            # the next one (if any) the clobber register.
            r2 = [False, "%{}".format(free[0])]
            if len(free) > 1:
                r1 = [False, "%{}".format(free[1])]

            # Only one free register and it collides with r1's default:
            # fall back to %rsi (saved) for r1.
            if r2[1] == r1[1]:
                r1 = [True, "%rsi"]

            # With at least one free register the "opt" flag-saving
            # sequence is used instead of the "unopt" one.
            if save_rflags:
                save_rflags = "opt"
                save_rax = "rax" not in free

        # Linux on x86_64 doesn't use red zones
        # https://github.com/torvalds/linux/blob/9f159ae07f07fc540290f21937231034f554bdd7/arch/x86/Makefile#L132
        # if is_leaf and (r1[0] or r2[0] or save_rflags):
        #     save.append(sp.LEAF_STACK_ADJUST)
        #     restore.append(sp.LEAF_STACK_UNADJUST)
        #     push_cnt += 32

        # Saves are appended while restores are inserted at the front, so
        # the restore sequence mirrors the save sequence in reverse.
        if r1[0]:
            save.append(copy.copy(sp.MEM_REG_SAVE)[0].format(reg=r1[1]))
            restore.insert(0,
                           copy.copy(sp.MEM_REG_RESTORE)[0].format(reg=r1[1]))
            push_cnt += 1

        if r2[0]:
            save.append(copy.copy(sp.MEM_REG_SAVE)[0].format(reg=r2[1]))
            restore.insert(0,
                           copy.copy(sp.MEM_REG_RESTORE)[0].format(reg=r2[1]))
            push_cnt += 1

        if save_rflags == "unopt":
            save.append(copy.copy(sp.MEM_FLAG_SAVE)[0])
            restore.insert(0, copy.copy(sp.MEM_FLAG_RESTORE)[0])
            push_cnt += 1
        elif save_rflags == "opt":
            # Counted as one stack slot (presumably what MEM_FLAG_SAVE_OPT
            # pushes; verify against the template).
            push_cnt += 1
            if save_rax:
                # rax is live: wrap the flag save with the
                # MEM_REG_REG_SAVE_RESTORE template (presumably a register
                # move), first rax -> r2, then r2 -> rax afterwards.
                save.append(
                    copy.copy(sp.MEM_REG_REG_SAVE_RESTORE)[0].format(
                        src="%rax", dst=r2[1]))
                save.extend(copy.copy(sp.MEM_FLAG_SAVE_OPT))

                save.append(
                    copy.copy(sp.MEM_REG_REG_SAVE_RESTORE)[0].format(
                        dst="%rax", src=r2[1]))

                # The restore sequence ends up starting with: rax -> r2,
                # MEM_FLAG_RESTORE_OPT, r2 -> rax, then the earlier restores.
                restore.insert(
                    0,
                    copy.copy(sp.MEM_REG_REG_SAVE_RESTORE)[0].format(
                        dst="%rax", src=r2[1]))

                restore = copy.copy(sp.MEM_FLAG_RESTORE_OPT) + restore

                restore.insert(
                    0,
                    copy.copy(sp.MEM_REG_REG_SAVE_RESTORE)[0].format(
                        src="%rax", dst=r2[1]))
            else:
                save.extend(copy.copy(sp.MEM_FLAG_SAVE_OPT))
                restore = copy.copy(sp.MEM_FLAG_RESTORE_OPT) + restore

        if push_cnt > 0 and '%rsp' in lexp:
            # In this case we have a stack-relative load but the value of the stack
            # pointer has changed because we pushed some registers to the stack
            # to save them. Adjust the displacement of the access to take this
            # into account
            disp = instruction.cs.operands[midx].value.mem.disp
            adjusted_disp = disp + push_cnt * 8

            # NOTE(review): if disp is 0 and the operand text carries no
            # explicit displacement (e.g. "(%rsp)"), neither branch matches
            # and the assert below fires; confirm whether such operands can
            # reach this point.
            # NOTE(review): ``assert`` is stripped under ``python -O``, in
            # which case an unmatched displacement would pass silently.
            if hex(disp) in lexp:
                lexp = lexp.replace(hex(disp), hex(adjusted_disp))
            elif str(disp) in lexp:
                lexp = lexp.replace(str(disp), hex(adjusted_disp))
            else:
                assert False, 'Can\'t find displacement in lexp'

        # Pick the check template for the access size.
        if acsz == 1:
            memcheck = self._access1()
        elif acsz == 2:
            memcheck = self._access2()
        elif acsz == 4:
            memcheck = self._access4()
        elif acsz == 8:
            memcheck = self._access8()
        else:
            assert False, "Reached unreachable code!"

        # Assemble: saves, check, exit label, restores. The exit label is
        # formatted with the address right away; the rest is formatted at
        # the end with ``args``.
        codecache.extend(save)
        codecache.append(memcheck)
        codecache.append(
            copy.copy(sp.MEM_EXIT_LABEL)[0].format(addr=instruction.address))
        codecache.extend(restore)

        args = dict()
        args["lexp"] = lexp
        args["acsz"] = acsz
        args["acsz_1"] = acsz - 1

        # Clobber register and its 32-bit / low 8-bit sub-register names
        # (the leading '%' is stripped before the lookup).
        args["clob1"] = r1[1]
        args["clob1_32"] = "%{}".format(self._get_subreg32(r1[1][1:]))
        args["clob1_8"] = "%{}".format(self._get_subreg8l(r1[1][1:]))

        # Target register and its sub-register names.
        args["tgt"] = r2[1]
        args["tgt_32"] = "%{}".format(self._get_subreg32(r2[1][1:]))
        args["tgt_8"] = "%{}".format(self._get_subreg8l(r2[1][1:]))

        args["addr"] = instruction.address
        enter_lbl = "%s_%s" % (sp.ASAN_MEM_ENTER, instruction.address)

        args['save_regs'] = ''
        args['restore_regs'] = ''
        # Uses the (possibly overridden) midx.
        args['acctype'] = 'load' if midx == 0 else 'store'

        codecache = '\n'.join(codecache)
        comment = "{}: {}".format(str(instruction), str(free))

        if is_rep_stos:
            # Emit a second copy of the check for the operand
            # "(%rdi, %rcx)", with "_2" appended to the address used in
            # its labels.
            copycache = copy.copy(codecache)
            extend_args_check = copy.copy(args)
            extend_args_check["lexp"] = "(%rdi, %rcx)"
            extend_args_check["addr"] = '{}_2'.format(instruction.address)
            copycache = copycache.format(**extend_args_check)
            # The exit label was already formatted with the plain address
            # above, so it is renamed by textual replacement.
            original_exit = copy.copy(
                sp.MEM_EXIT_LABEL)[0].format(addr=instruction.address)

            new_exit = copy.copy(sp.MEM_EXIT_LABEL)[0].format(
                addr='{}_2'.format(instruction.address))
            copycache = copycache.replace(original_exit, new_exit)
            # The already formatted copy is passed through ``format`` once
            # more as part of the combined string below.
            codecache = codecache + "\n" + copycache

        return InstrumentedInstruction(codecache.format(**args), enter_lbl,
                                       comment)