def __init__(self, program, address):
    """Look up the data defined at ``address`` and classify it.

    After construction ``type``/``value`` hold one of:

    * ``None``/``None`` - nothing is defined at ``address``, or the data is
      a pointer whose primary reference or target could not be resolved.
    * ``'function'``/function - the data is a pointer to a function.
    * ``'data'``/data - the data is a pointer to another data item.
    * the item's own data type / the item itself - the data is not a
      pointer.

    :param program: Program the address belongs to.
    :param address: Address of the data item to inspect.
    """
    self._program = program
    self._flat_api = FlatProgramAPI(self._program)
    self._data = self._flat_api.getDataAt(address)

    self.address = address
    self.type = None
    self.value = None

    if not self._data:
        return

    # Non-pointer data describes itself.
    if not self._data.isPointer():
        self.type = self._data.dataType
        self.value = self._data
        return

    # Pointer: follow the primary reference of operand 0 to its target.
    primary_ref = self._data.getPrimaryReference(0)
    if not primary_ref:
        return

    target = primary_ref.toAddress
    pointed_function = self._flat_api.getFunctionAt(target)
    if pointed_function:
        self.type = 'function'
        self.value = pointed_function
        return

    pointed_data = self._flat_api.getDataAt(target)
    if pointed_data:
        self.type = 'data'
        self.value = pointed_data
def main():
    """Write matched ntdll syscall values into restart-service.exe.

    Asks the user for the two programs, walks the ntdll export table via
    ``AddressExplorer`` and, for every export whose name starts with "Nt" and
    whose hash (``fnv1a_32`` over the name bytes with the first two bytes and
    the last byte stripped) is one of the known hashes, reads the 32-bit
    value located 4 bytes past the export's entry point and stores it at the
    matching address in restart-service.exe.

    All writes happen inside one transaction. The transaction is committed
    only when the whole loop succeeded and is always closed, so a failure
    half-way through no longer leaves the transaction open.

    NOTE(review): ``getFunctionAt``/``getInt``/``toAddr`` are the script-level
    helpers, so they presumably operate on the script's current program
    rather than on ``ntdll_program`` - confirm the script is run with ntdll
    as the current program.
    """
    try:
        rs_program = askProgram("Choose the restart-service.exe executable")
        ntdll_program = askProgram("Choose the ntdll.dll shared library")
    except CancelledException:
        return

    rs_api = FlatProgramAPI(rs_program)

    # Known name hash -> address in restart-service.exe that receives the
    # value read from the matching ntdll export.
    addresses = {
        0x3944AA7E: rs_api.toAddr(0x004ab004),
        0x7EA69E72: rs_api.toAddr(0x004ab008),
        0xDBA7A248: rs_api.toAddr(0x004ab00c),
        0x57FF1EA4: rs_api.toAddr(0x004ab010),
        0x71948CA4: rs_api.toAddr(0x004ab014),
    }

    explorer = AddressExplorer(ntdll_program)
    memory = rs_program.getMemory()
    transaction = ProgramTransaction.open(rs_program,
                                          "Writing matched syscall IDs")
    try:
        for func_name, func_addr in zip(explorer.name_iter(),
                                        explorer.func_addr_iter()):
            if not func_name:
                continue

            name_hash = fnv1a_32(func_name.getBytes()[2:-1])
            if name_hash not in addresses:
                continue
            if not func_name.getValue().startswith("Nt"):
                continue

            try:
                syscall_dword = getInt(toAddr(func_addr.getOffset() + 4))
            except MemoryAccessException:
                # Unreadable stub: nothing to copy for this export.
                continue

            # The name Ghidra assigned to the function is only used for the
            # diagnostic output below.
            func = getFunctionAt(func_addr)
            func_name_orig = func.getName() if func else None
            if func_name_orig:
                s_orig = fnv1a_32([ord(c) for c in func_name_orig[2:]])
            else:
                s_orig = 0x0

            print(
                "name: %s, name_orig: %s, func_addr: %s, syscall: %s, s: 0x%X, s_orig: 0x%X"
                % (func_name, func_name_orig, func_addr, syscall_dword,
                   name_hash, s_orig))
            # The value is written to the address (the old message had the
            # two swapped).
            print("Writing 0x%X to %s" % (syscall_dword,
                                          addresses[name_hash]))
            memory.setInt(addresses[name_hash], syscall_dword)

        transaction.commit()
    finally:
        transaction.close()
Esempio n. 3
0
    def __init__(self, plugin, state=None, logger_fname="ghidra_emulator.txt"):
        """Set up the emulator for the program/address of a Ghidra state.

        :param plugin: Object providing ``getMonitor()`` and
                       ``getGhidraState()``.
        :param state: State to take the current program and address from;
                      when ``None`` the plugin's own state is used.
        :param logger_fname: File name handed to ``initLogger``.
        """
        self.plugin = plugin
        self.monitor = plugin.getMonitor()

        ghidra_state = state if state is not None else plugin.getGhidraState()
        current_program = ghidra_state.getCurrentProgram()
        current_address = ghidra_state.getCurrentAddress()

        self.byte_substitution = {}

        # Logger first, then the emulator itself, then the command handlers.
        self.initLogger(logger_fname)
        self.initEmulator(current_program, current_address)
        self.initCmdHandlers()

        self.emulator_state = EmulatorState.WAITING_FOR_PARAM
        self.flatapi = FlatProgramAPI(current_program)
    def __init__(self, verbose, api_db):
        """Collect import references, prepare the regexes and the decompiler.

        :param verbose: Stored as ``self.verbose``.
        :param api_db: Path to an API database; kept in ``self.api_db_path``
                       only when it names an existing file.
        """
        self.verbose = verbose
        if os.path.isfile(api_db):
            self.api_db_path = api_db
        else:
            self.api_db_path = None
        self.api_db = None

        symbols = currentProgram.getSymbolTable()

        # References (code or data) to GetProcAddress. When the first one is
        # a data reference, the references to its source address are added.
        gpa_symbol = symbols.getExternalSymbol("GetProcAddress")
        if gpa_symbol:
            self.gpa_refs = gpa_symbol.getReferences()
        else:
            self.gpa_refs = []
        if len(self.gpa_refs) > 0:
            first_ref = self.gpa_refs[0]
            if first_ref.getReferenceType().isData():
                self.gpa_refs.extend(getReferencesTo(first_ref.getFromAddress()))

        # References (code or data) to LoadLibrary and similar functions.
        module_loader_names = ("LoadLibraryA", "LoadLibraryW",
                               "LoadLibraryExA", "LoadLibraryExW",
                               "GetModuleHandleA", "GetModuleHandleW")
        loader_refs = []
        for loader_name in module_loader_names:
            loader_symbol = symbols.getExternalSymbol(loader_name)
            if not loader_symbol:
                continue
            loader_refs.extend(loader_symbol.getReferences())
        self.load_lib_functions = DynamicImportsEnumerator.__get_functions_containing_refs(loader_refs)

        # GetProcAddress call: group 1 is hModule, group 2 is lpProcName.
        gpa_pattern = (r"GetProcAddress[ \t]?\("
                       r"[ \t]?(.+?)[ \t]?,"
                       r"[ \t]?(?:\(.+?\))?[ \t]?(.+?)[ \t]?[),]")
        self.getprocaddress_regex = re.compile(gpa_pattern)

        # Templates below contain a ``{}`` placeholder that is filled in
        # with ``str.format`` before compilation.

        # First argument (the dll name) of LoadLibrary and similar calls.
        self.loadlibrary_regex = \
            r"{}\s?=\s?(?:LoadLibrary(?:Ex)?|GetModuleHandle)(?:[AW])?\s?\(\s?(?:\(.+?\))?\s?(.+?)\s?[),]"
        # Works around the wrong hModule bug in the Ghidra decompiler.
        self.wrong_hmodule_regex = r"hModule(_[0-9]*)"
        # C function call together with its argument list.
        self.function_call_regex = r"{}[ \t]?\([ \t]?(.*)[ \t]?\)"
        # Casts to HMODULE.
        self.hmodule_cast_regex = r"(?:\*[ \t]?\([ \t]?HMODULE[ \t]?\*\))?\(?(?:\(.*\))?{}\)?"
        # Variable alias assignments, in both directions.
        self.alias_regex = r"([a-zA-Z_][a-zA-Z0-9_]*)[ \t]?=[ \t]?(?:\(.+?\)\s*)?{}[ \t]*;"
        self.alias_regex2 = r"{}[ \t]?=[ \t]?([a-zA-Z_][a-zA-Z0-9_]*)[ \t]*;"

        # Calls to internal functions.
        self.internal_function_regex = r"((?:FUN|UndefinedFunction)\_[a-fA-F0-9]+)"
        self.internal_call_regex = re.compile(
            self.function_call_regex.format(self.internal_function_regex))

        # Calls to internal functions or to string/memory copy/move APIs.
        copy_or_internal_names = (
            r"((?:FUN|UndefinedFunction)\_[a-fA-F0-9]+|"
            r"str(?:n)?cpy(?:_s)?|w?mem(?:cpy|move)(?:_s)?|basic_string<>)")
        self.call_regex = re.compile(
            self.function_call_regex.format(copy_or_internal_names))

        # Decompiler set-up. initialize() has to be called explicitly so the
        # options can be applied; the payload limit is raised to 200 MiB
        # because some programs need more than the default 50 MiB.
        self.flat_program = FlatProgramAPI(currentProgram)
        self.flat_decompiler = FlatDecompilerAPI(self.flat_program)
        options = DecompileOptions()
        options.setMaxPayloadMBytes(200)
        self.flat_decompiler.initialize()
        self.flat_decompiler.getDecompiler().setOptions(options)
        self.decompiled_functions = {}
Esempio n. 5
0
    def __init__(self, program):
        """Cache program services, collect strings and generate signatures.

        The signatures returned by ``_generate`` are stored in
        ``self._signatures`` and a one-line timing summary is printed.

        :param program: Program to generate signatures for.
        """
        self._program = program
        self._flat_api = FlatProgramAPI(self._program)
        self._memory_map = self._program.getMemory()
        self._simple_blk = BasicBlockModel(self._program)
        self._monitor = self._flat_api.getMonitor()
        self._function_manager = self._program.getFunctionManager()
        self._address_factory = self._program.getAddressFactory()

        self.signatures = None
        self._strings = {}
        self._find_strings()

        # Time only the signature generation itself.
        started = time.time()
        self._signatures = self._generate()
        elapsed = time.time() - started

        summary = ('Generated %d formal signatures and %d fuzzy signatures '
                   'for %d functions in %.2f seconds.')
        counts = (len(self._signatures.formal),
                  len(self._signatures.fuzzy),
                  len(self._signatures.functions),
                  elapsed)
        print(summary % counts)
Esempio n. 6
0
class FinderBase(object):
    """Common state and table-printing helper shared by finder classes."""

    def __init__(self, program):
        """
        Cache the flat API, task monitor and basic block model of a program.

        :param program: Program the finder operates on.
        """
        self._program = program
        self._flat_api = FlatProgramAPI(program)
        self._monitor = self._flat_api.getMonitor()
        self._basic_blocks = BasicBlockModel(self._program)

    def _display(self, title, entries):
        """
        Print a simple table to the terminal.

        Every column is left-aligned and padded to its widest cell. Rows are
        rendered as ``| cell | cell |`` and framed by ``=`` separator lines
        of exactly the same width, for any number of columns.

        :param title: Column headings of the table.
        :type title: list(str)

        :param entries: Rows to print; each row needs at least as many cells
                        as there are headings.
        :type entries: list(list(str))
        """
        rows = [title] + entries

        # Widest cell of every column, headings included.
        widths = [max(len(row[index]) for row in rows)
                  for index in range(len(title))]

        # A rendered row is the leading '|' plus, per column, a space, the
        # padded cell, a space and a '|' - i.e. three extra characters per
        # column. (The previous formula only matched this for 4 columns.)
        separator = '=' * (sum(widths) + 3 * len(title) + 1)

        def render(row):
            cells = [' {:<{width}} |'.format(cell, width=width)
                     for width, cell in zip(widths, row)]
            return '|' + ''.join(cells)

        # Heading framed by separators, then the entries, then a closing
        # separator. Each line is printed as one string so no stray trailing
        # space is emitted.
        print(separator)
        print(render(title))
        print(separator)
        for entry in entries:
            print(render(entry))
        print(separator)
Esempio n. 7
0
    def init(self, cp=None):
        """Bind this object to a program and cache its commonly used managers.

        The program is taken from ``cp`` when given; otherwise a module-level
        ``currentProgram`` is used if one exists. When a program is available
        the function manager, address factory, data type manager, memory,
        symbol table, code manager and a ``FlatProgramAPI`` are stored on the
        instance. Without a program ``self.cp`` keeps any value it already
        has, or becomes ``None``, and nothing else is set.

        :param cp: Program to use, or ``None`` to fall back to the
                   module-level ``currentProgram``.
        """
        if not cp:
            # globals() is a dict, so the name has to be looked up as a key;
            # hasattr(globals(), ...) was always False.
            cp = globals().get('currentProgram')

        if cp:
            self.cp = cp
        elif not hasattr(self, 'cp'):
            # Make sure the attribute exists so the check below cannot raise.
            self.cp = None

        if self.cp:
            self.fmgr = self.cp.getFunctionManager()
            self.afac = self.cp.getAddressFactory()
            self.dtmgr = self.cp.getDataTypeManager()
            self.mem = self.cp.getMemory()
            self.st = self.cp.getSymbolTable()
            self.cm = self.cp.getCodeManager()
            self.flatapi = FlatProgramAPI(self.cp)
Esempio n. 8
0
    def _get_argument_count(self):
        """
        Get argument count through decompiler if possible otherwise try to
        determine the argument count manually. Manual approach can miss
        arguments if they are used in the first function call of the function.

        The decompiler created here is always disposed before returning so
        repeated calls do not accumulate decompiler instances.

        :returns: Number of parameters reported by the decompiled function
                  prototype, or the result of ``_get_argument_count_manual``
                  when the decompiler yields no prototype.
        """
        flat_api = FlatProgramAPI(currentProgram)
        decompiler_api = FlatDecompilerAPI(flat_api)

        try:
            # Must call decompile first or the decompiler will not be
            # initialized.
            decompiler_api.decompile(self.function)
            decompiler = decompiler_api.getDecompiler()

            if decompiler:
                decompiled_fn = decompiler.decompileFunction(
                    self.function, 10, getMonitor())
                if decompiled_fn:
                    high_level_fn = decompiled_fn.getHighFunction()
                    if high_level_fn:
                        prototype = high_level_fn.getFunctionPrototype()
                        if prototype:
                            return prototype.getNumParams()
        finally:
            # Release the decompiler on every path, including exceptions.
            decompiler_api.dispose()

        return self._get_argument_count_manual()
Esempio n. 9
0
class Syscalls:
    """Locate system call instructions in an ELF program and annotate them.

    Constructing an instance performs the whole analysis: it searches for the
    configured opcode bytes, resolves the value of the syscall-number
    register at each hit and adds an end-of-line comment, a bookmark and,
    when the function signature is known, post comments marking where each
    argument register is set.
    """

    flatProgram = None
    symEval = None
    monitor = None
    currentProgram = None
    currentSelection = None

    def getSyscalls(self, opcode, mnemonic):
        """Find all the places where the system call appears.

        :param opcode: Byte pattern handed to ``findBytes``.
        :param mnemonic: Text that must occur in the code unit found at a hit
                         for it to count.
        :returns: List of matching addresses (at most 8192 hits are
                  searched). An empty list is returned when the monitor was
                  cancelled, so callers can always iterate the result.
        """
        calls = []
        listing = self.currentProgram.getListing()
        locations = self.flatProgram.findBytes(
            self.currentProgram.getMinAddress(), opcode, 8192)

        for addr in locations:
            if self.monitor.isCancelled():
                self.doCancel()
                return []

            ins = listing.getCodeUnitAt(addr)
            if ins is not None and mnemonic in ins.toString():
                calls.append(addr)

        return calls

    def getRegisterValue(self, addr, register):
        """Get a register value at a certain address through symbolic
        propagation.

        :param addr: Address at which the register is evaluated.
        :param register: Register to evaluate.
        :returns: The propagated value, or ``None`` when ``addr`` is not
                  inside a function or no value is known.
        """
        function = self.currentProgram.getListing().getFunctionContaining(addr)
        if function is None:
            return None

        evaluate = ConstantPropagationContextEvaluator(True)
        self.symEval.flowConstants(function.getEntryPoint(),
                                   function.getBody(), evaluate, False,
                                   self.monitor)

        result = self.symEval.getRegisterValue(addr, register)
        if result is None:
            return None
        return result.getValue()

    def getSignature(self, name, data):
        """Get a function signature.

        :param name: Function name.
        :param data: Dict with a ``'ret'`` string and an ``'args'`` list of
                     strings.
        :returns: ``'<ret> <name>(<arg>, <arg>, ...)'``.
        """
        return '%s %s(%s)' % (data['ret'], name, ', '.join(data['args']))

    def markArguments(self, args, addr, data):
        """Mark the positions of arguments for the current syscall.

        Walks backwards from the syscall over fall-through code units (not
        past the start of the address range containing ``addr``), then scans
        forward again and remembers, per argument register, the last code
        unit before the syscall that lists the register among its result
        objects. Each such code unit receives the argument description as a
        post comment.

        :param args: Argument register names, in argument order.
        :param addr: Address of the syscall instruction.
        :param data: Function description; ``data['args']`` holds the
                     argument descriptions.
        :returns: ``None``.
        """
        listing = self.currentProgram.getListing()
        function = listing.getFunctionContaining(addr)
        if function is None:
            return None

        ins = listing.getCodeUnitAt(addr)
        if ins is None:
            return None

        # Start of the address range of the function body holding ``addr``.
        # Stays None when no range matches, which simply removes that stop
        # condition instead of raising on an unbound name.
        base = None
        for block in function.getBody():
            if addr >= block.getMinAddress() and addr <= block.getMaxAddress():
                base = block.getMinAddress()
                break

        start = ins.getAddress()
        curr = ins.getPrevious()
        while curr is not None:
            if curr.getFlowType().toString() != 'FALL_THROUGH':
                break

            start = curr.getAddress()
            if base is not None and start.equals(base):
                break

            curr = curr.getPrevious()

        # Pair registers with descriptions; surplus registers are ignored
        # and the first pairing wins should a register appear twice.
        affected = {}
        for register_name, description in zip(args, data['args']):
            affected.setdefault(register_name,
                                {'addr': None, 'comment': description})

        curr = listing.getCodeUnitAt(start)
        while curr is not None:
            addy = curr.getAddress()
            if addy.equals(addr):
                break

            for which in curr.getResultObjects():
                key = which.toString()
                if key in affected:
                    affected[key]['addr'] = addy

            curr = curr.getNext()

        for which in affected:
            if affected[which]['addr'] is None:
                continue
            self.flatProgram.setPostComment(affected[which]['addr'],
                                            affected[which]['comment'])

        return None

    def doCancel(self):
        """Cancel message."""
        print('Operation cancelled')

    def loadData(self, kind, arch, abi):
        """Load data in a multilayered manner.

        Looks for ``generic_<kind>.json``, ``<arch>_<kind>.json`` and
        ``<arch>_<abi>_<kind>.json`` in ``../data/syscalls`` relative to this
        file, in that order. Dict layers are merged, with entries of later
        (more specific) layers overriding earlier ones; previously every
        layer after the first one found was silently dropped.

        :param kind: Data kind, used as the file name suffix.
        :param arch: Architecture name.
        :param abi: ABI name.
        :returns: The merged data, or ``None`` when no layer file exists.
        """
        final = None
        here = os.path.dirname(os.path.realpath(__file__))

        for layer in ('generic', arch, arch + '_' + abi):
            filepath = '%s/../data/syscalls/%s_%s.json' % (here, layer, kind)
            filename = os.path.realpath(filepath)
            if not os.path.isfile(filename):
                continue

            with open(filename) as handle:
                data = json.loads(handle.read())
            if data is None:
                continue

            if isinstance(final, dict) and isinstance(data, dict):
                final.update(data)
            else:
                final = data

        return final

    def __init__(self, program, selection, monitor, arch, abi='default'):
        """Run the syscall annotation over ``program``.

        :param program: Program to analyse; must be an ELF file.
        :param selection: Optional selection; when not ``None`` only syscalls
                          between its minimum and maximum address are
                          processed.
        :param monitor: Task monitor used for cancellation.
        :param arch: Key into ``ARCHS``.
        :param abi: Key into ``ARCHS[arch]``.
        """
        self.currentProgram = program
        self.currentSelection = selection
        self.monitor = monitor
        self.flatProgram = FlatProgramAPI(program, monitor)
        self.symEval = SymbolicPropogator(self.currentProgram)

        if self.currentProgram.getExecutableFormat() != ElfLoader.ELF_NAME:
            popup('Not an ELF file, cannot continue')
            return

        if arch not in ARCHS:
            popup('Architecture not defined')
            return

        if abi not in ARCHS[arch]:
            popup('ABI not defined')
            return

        # Missing data files yield empty tables instead of a TypeError on
        # the membership tests below.
        global SYSCALLS, FUNCTIONS
        SYSCALLS = self.loadData('syscalls', arch, abi) or {}
        FUNCTIONS = self.loadData('functions', arch, abi) or {}

        data = ARCHS[arch][abi]
        endian = self.currentProgram.getLanguage().getLanguageDescription(
        ).getEndian().toString()
        # The syscall-number register is the same for every hit.
        reg = self.currentProgram.getRegister(data['reg'])

        for row in data['ins']:
            if row['endian'] != endian:
                continue

            calls = self.getSyscalls(row['opcode'], row['interrupt'])
            if self.monitor.isCancelled():
                return

            for call in calls:
                if self.currentSelection is not None:
                    if call < self.currentSelection.getMinAddress():
                        continue
                    if call > self.currentSelection.getMaxAddress():
                        continue

                res = self.getRegisterValue(call, reg)
                if res is None:
                    continue

                res = str(res)
                if res not in SYSCALLS:
                    continue

                syscall = SYSCALLS[res]
                comment = syscall

                if syscall in FUNCTIONS:
                    comment = self.getSignature(syscall, FUNCTIONS[syscall])
                    self.markArguments(data['arg'], call, FUNCTIONS[syscall])

                self.flatProgram.setEOLComment(call, comment)
                self.flatProgram.createBookmark(
                    call, 'Syscall', 'Found %s -- %s' % (syscall, comment))
Esempio n. 10
0
def main():
    """Remove CALL_RETURN flow overrides from calls to the current function.

    For the function under the cursor, every reference to the start of its
    body is visited in address order. When the referencing instruction is a
    ``bl``/``blx`` carrying a ``CALL_RETURN`` flow override, the code right
    after the call is disassembled (in the same ARM/Thumb mode as the call)
    and the override is reset to ``NONE``.

    The loop stops early when the user cancels or when the target function
    becomes no-return again.
    """
    target_function = getFunctionContaining(currentAddress)
    if target_function is None:
        print("Please place the cursor within a function!")
        return

    rm = currentProgram.getReferenceManager()
    target_address = target_function.getBody().getMinAddress()

    references = sorted(rm.getReferencesTo(target_address),
                        key=lambda ref: ref.getFromAddress())
    maximum = len(references)

    monitor.setIndeterminate(False)
    monitor.initialize(maximum)
    monitor.setCancelEnabled(True)
    monitor.setProgress(0)
    monitor.setMessage("Fixing up %s references..." % target_function.getName())

    edit_count = 0

    for cur, ref in enumerate(references):
        # Cancellation is enabled above, so honour it.
        if monitor.isCancelled():
            print("Cancelled by user")
            break

        if cur > 1 and cur % 1000 == 0 and edit_count > 100:
            # let auto analysis catch up
            time.sleep(10.0)
            edit_count = 0

        if target_function.hasNoReturn():
            print("Function has gone into no-return!")
            break

        monitor.setProgress(cur + 1)

        caddr = ref.getFromAddress()

        insn = getInstructionAt(caddr)
        if insn is None:
            # E.g. a reference coming from data rather than from code.
            print("[%d/%d] [%s] skipping, no instruction" % (cur + 1, maximum, caddr))
            continue

        op = insn.toString().split(" ")[0]
        if not op.startswith("bl"):
            print("[%d/%d] [%s] skipping non bl/blx" % (cur + 1, maximum, caddr))
            continue

        foverride = insn.getFlowOverride()

        if foverride == FlowOverride.CALL_RETURN:
            print("[%d/%d] [%s] fixing up" % (cur + 1, maximum, caddr))

            # Disassemble the fall-through in the same mode as the call.
            pctx = insn.getInstructionContext().getProcessorContext()
            tmode_reg = pctx.getRegister("TMode")
            tmode = pctx.getRegisterValue(tmode_reg)
            tmode = tmode.getUnsignedValue()

            # Use the real instruction length: a Thumb "blx <reg>" is only
            # 2 bytes long, so a fixed +4 would skip an instruction.
            next_addr = caddr.add(insn.getLength())
            cmd = ArmDisassembleCommand(next_addr, None, tmode == 1)
            fcmd = SetFlowOverrideCmd(caddr, FlowOverride.NONE)

            if not cmd.applyTo(currentProgram):
                print("Failed to submit disassembly cmd for %s" % caddr)
                continue

            if not fcmd.applyTo(currentProgram):
                print("Failed to submit flow override cmd for %s" % caddr)
                continue

            edit_count += 1
        elif foverride == FlowOverride.NONE:
            print("[%d/%d] [%s] skipping, no override" % (cur + 1, maximum, caddr))
        else:
            print("[%d/%d] [%s] unexpected override %s" %
                  (cur + 1, maximum, caddr, foverride))
Esempio n. 11
0
class AddressExplorer(object):
    """Walk the export directory of a PE image loaded in Ghidra.

    All table locations are computed from the program's image base plus the
    relative offsets stored in the PE headers.

    NOTE(review): the export data-directory entry is read at
    ``e_lfanew + 0x88``, which matches the PE32+ (64-bit) optional header
    layout; a 32-bit PE32 image would presumably need ``0x78`` - confirm
    against the images this is used on.
    """

    def __init__(self, currentProgram):
        """
        :param currentProgram: Program containing the PE image to inspect.
        """
        self.p = currentProgram
        self.a = FlatProgramAPI(self.p)
        self.image_base_offset = self.p.getAddressMap().imageBase.getOffset()

    def image_directory_entry_export_addr(self):
        """Return the address of the export data-directory entry."""
        # 0x3c holds the offset of the PE header (e_lfanew).
        offs1 = self.a.getInt(self.a.toAddr(self.image_base_offset + 0x3c))
        return self.a.toAddr(self.image_base_offset + 0x88 + offs1)

    def image_directory_entry_export(self):
        """Return the offset of the export directory from the image base."""
        return self.a.getInt(self.image_directory_entry_export_addr())

    def export_number_of_names_addr(self):
        """Return the address of the export directory's NumberOfNames."""
        return self.a.toAddr(self.image_base_offset + 0x18 +
                             self.image_directory_entry_export())

    def export_number_of_names(self):
        """Return the number of exported names."""
        return self.a.getInt(self.export_number_of_names_addr())

    def address_of_names(self):
        """Return the address of the export directory's AddressOfNames."""
        return self.a.toAddr(self.image_base_offset + 0x20 +
                             self.image_directory_entry_export())

    def address_of_names_offset(self):
        """Return the offset of the name pointer table from the image base."""
        return self.a.getInt(self.address_of_names())

    def v3_addr(self):
        """Print name-table diagnostics and return an address inside the
        last exported name.

        :returns: Address 2 bytes past the start of the name referenced by
                  the last entry of the name pointer table.
        """
        offs1_addr = self.address_of_names()
        print("v3_addr:offs1_addr = AddressOfNames Offset Address = %s" %
              offs1_addr)
        offs1 = self.address_of_names_offset()
        print("v3_addr:offs1 = AddressOfNames Offset = 0x%x" % offs1)
        print("v3_addr:offs1 = AddressOfNames Address = %s" %
              self.a.toAddr(self.image_base_offset + offs1))
        offs2_addr = self.a.toAddr(self.image_base_offset +
                                   (self.export_number_of_names() - 1) * 4 +
                                   offs1)
        print("v3_addr:offs2_addr = %s" % offs2_addr)
        offs2 = self.a.getInt(offs2_addr)
        print("v3_addr:offs2 = 0x%x" % offs2)
        return self.a.toAddr(self.image_base_offset + 2 + offs2)

    def get_dll_name(self):
        """Return the data item at the export directory's Name offset."""
        name_offset = self.a.getInt(
            self.a.toAddr(self.image_base_offset + 0x0c +
                          self.image_directory_entry_export()))
        return self.a.getDataAt(
            self.a.toAddr(name_offset + self.image_base_offset))

    def name_iter(self):
        """Yield the data item of every exported name, last entry first.

        Exactly ``export_number_of_names()`` items are produced, in the same
        order as ``func_addr_iter`` so the two can be zipped. (The previous
        loop ran one step too far and also read the entry at index -1.)
        """
        number_of_names = self.export_number_of_names()
        address_of_names_offset = self.address_of_names_offset()
        for index in reversed(range(number_of_names)):
            next_name_addr_offset = self.a.getInt(
                self.a.toAddr(self.image_base_offset + index * 4 +
                              address_of_names_offset))
            yield self.a.getDataAt(
                self.a.toAddr(self.image_base_offset + next_name_addr_offset))

    def address_of_name_ordinals_offset(self):
        """Return the offset of the name ordinal table from the image base."""
        return self.a.getInt(
            self.a.toAddr(self.image_base_offset + 0x24 +
                          self.image_directory_entry_export()))

    def address_of_functions_offset(self):
        """Return the offset of the function address table from the image
        base."""
        return self.a.getInt(
            self.a.toAddr(self.image_base_offset + 0x1c +
                          self.image_directory_entry_export()))

    def func_addr_iter(self):
        """Yield the address of every named export, last entry first.

        For each name index the 16-bit entry of the name ordinal table
        selects the slot of the function address table. Exactly
        ``export_number_of_names()`` addresses are produced, matching
        ``name_iter``.
        """
        number_of_names = self.export_number_of_names()
        address_of_name_ordinals_offset = self.address_of_name_ordinals_offset()
        address_of_functions_offset = self.address_of_functions_offset()
        for index in reversed(range(number_of_names)):
            # getShort is signed; the table entry is an unsigned index.
            ordinal = self.a.getShort(
                self.a.toAddr(self.image_base_offset + index * 2 +
                              address_of_name_ordinals_offset)) & 0xFFFF
            next_func_offset = self.a.getInt(
                self.a.toAddr(self.image_base_offset +
                              address_of_functions_offset +
                              ordinal * 4))
            yield self.a.toAddr(self.image_base_offset + next_func_offset)
Esempio n. 12
0
class Emulator(object):
    def __init__(self, plugin, state=None, logger_fname="ghidra_emulator.txt"):
        """
        Build an emulator for the function at the state's current address.

        :param plugin: Script/plugin object providing the monitor, the Ghidra
                       state and the user prompts.
        :param state: Ghidra state to use; the plugin's state when None.
        :param logger_fname: File receiving a copy of the log; a falsy value
                             disables the file handler.
        """
        self.plugin = plugin
        self.monitor = plugin.getMonitor()

        ghidra_state = self.plugin.getGhidraState() if state is None else state
        current_program = ghidra_state.getCurrentProgram()
        current_address = ghidra_state.getCurrentAddress()

        self.byte_substitution = {}

        # The logger has to exist first: initEmulator() logs through it.
        self.initLogger(logger_fname)
        self.initEmulator(current_program, current_address)
        self.initCmdHandlers()

        self.emulator_state = EmulatorState.WAITING_FOR_PARAM
        self.flatapi = FlatProgramAPI(current_program)

    def initLogger(self, fname):
        """
        Create this emulator's logger, writing to stdout and optionally a file.

        :param fname: Path of the log file; when falsy only stdout is used.
        """
        self.logger_fname = fname

        # The logger is registered under a randomly generated name.
        logger_name = str(random.random()).replace(".","_")
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(logging.INFO)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.INFO)
        self.logger.addHandler(console_handler)

        if not self.logger_fname:
            return

        file_handler = logging.FileHandler(self.logger_fname)
        file_handler.setLevel(logging.INFO)
        self.logger.addHandler(file_handler)

    def initEmulator(self, program, address, clear_param_map=True):
        '''Set up the emulator helper, decompiler results and stack for one function.

        The function containing `address` is emulated. When `address` is not
        inside a function the user is asked for a function name or a hex
        entry-point address, which is looked up among the program's functions.

        :param program: Program to emulate.
        :param address: Address used to select the function.
        :param clear_param_map: When true the stored parameter values are
                                dropped and the emulator waits for new ones;
                                otherwise it is marked READY straight away.
        :raises ValueError: If the user's answer matches no function.
        '''
        self.program = program
        function_manager = self.program.getFunctionManager()
        self.function = function_manager.getFunctionContaining(address)
        if self.function is None:
            requested = self.plugin.askString("You are not in a function, please enter an address or a function name", "address or symbol name")
            candidates = list(function_manager.getFunctions(True))
            # A name match takes precedence over an entry-point match.
            match = next((f for f in candidates if f.getName() == requested), None)
            if match is None:
                try:
                    wanted_offset = int(requested, 16)
                except ValueError:
                    # The answer is not a hex number, so it can only be a name.
                    wanted_offset = None
                if wanted_offset is not None:
                    match = next((f for f in candidates
                                  if f.getEntryPoint().getOffset() == wanted_offset), None)
            if match is None:
                raise ValueError("No function matches `%s`" % requested)
            # Move the plugin's cursor to the chosen function and keep
            # initializing with it.
            self.plugin.state.setCurrentAddress(match.getEntryPoint())
            self.function = match
        self.entrypoint = self.program.getListing().getInstructionAt(self.function.getEntryPoint())

        self.logger.info("Program: %s" % self.program)
        self.logger.info("Function: %s" % self.function)

        self.decompinterface = DecompInterface()
        self.decompinterface.openProgram(program)
        result = self.decompinterface.decompileFunction(self.function, 0, self.monitor)
        self.highFunction = result.getHighFunction()
        # self.logger.info(result)
        # self.logger.info(self.highFunction)

        self.decompiled = str(result.getCCodeMarkup())
        # self.logger.info("Decompiled: %s" % self.decompiled)

        self.symbolMap = self.highFunction.getLocalSymbolMap()
        # self.logger.info(self.symbolMap)
        if clear_param_map:
            self.parameterMap = {}
        # fuzz = 0

        self.emulatorHelper = EmulatorHelper(self.program)
        sp_register = self.emulatorHelper.getStackPointerRegister()
        sp_bits = sp_register.getBitLength()
        # All bits below the top one set, with the lower half cleared
        # (0x7fff0000 for a 32-bit stack pointer).
        self.stackPointer = ((1 << (sp_bits - 1)) - 1) ^ ((1 << (sp_bits // 2)) - 1)
        self.returnAddressSize = program.getLanguage().getProgramCounter().getBitLength()

        NULL_PTR_RET = 0
        self.emulatorHelper.writeRegister(sp_register, self.stackPointer)
        # Break on the address at offset NULL_PTR_RET from the initial stack
        # pointer.
        self.emulatorHelper.setBreakpoint(self.getStackAddress(NULL_PTR_RET))
        self.emulatorHelper.enableMemoryWriteTracking(True)

        self.emulator_state = EmulatorState.WAITING_FOR_PARAM
        if not clear_param_map:
            self.emulator_state = EmulatorState.READY
        self.history = []

        self.lastAddresses = []
        # self.emulatorHelper.getEmulator().executeInstruction = executeInstruction

        self.hookExternalFunctions()
        # def nopCallBack(BreakCallBack):
        #     def __init__(self):
        #         # BreakCallBack.__init__(self)
        #         pass
        #     def pcodeCallback(self, op):
        #         return True
        # help(nopCallBack)
        # emulatorHelper.registerCallOtherCallback('HintPreloadData', nopCallBack(BreakCallBack()))

    def hookExternalFunctions(self):
        """
        Register Python implementations for known external and thunk functions.

        Every external function and every thunk function of the program is
        compared by name against each function exported by each library in
        `lib.exports`. On a match an address callback is registered: at every
        thunk address of an external function, or at the entry point of a
        thunk function.
        """
        def install(implementation, address):
            # Instantiate the implementation for the function at `address`
            # and register it as an address callback.
            target = self.program.getFunctionManager().getFunctionAt(address)
            callback = DucktapeBreakCallback(implementation(self.program, self, target, self.monitor), lambda x: True)
            self.emulatorHelper.emulator.getBreakTable().registerAddressCallback(address, callback)

        for external in list(self.program.getFunctionManager().getExternalFunctions()):
            self.logger.debug('Found external function `%s`' % (external.getName()))
            for library in lib.exports:
                self.logger.debug('Found library `%s`' % (library.name))
                for implementation in library.exports:
                    self.logger.debug('Found function `%s`' % (implementation.__name__))
                    if external.getName() != implementation.__name__:
                        continue
                    for address in external.getFunctionThunkAddresses():
                        self.logger.info('Hooked function `%s`@%s with implementation lib/%s/%s' % (external.getName(), str(address), library.name, implementation.__name__))
                        install(implementation, address)

        thunks = [f for f in self.program.getFunctionManager().getFunctions(True) if f.isThunk()]
        for thunk in thunks:
            for library in lib.exports:
                self.logger.debug('Found library `%s`' % (library.name))
                for implementation in library.exports:
                    self.logger.debug('Found function `%s`' % (implementation.__name__))
                    if thunk.getName() != implementation.__name__:
                        continue
                    address = thunk.getEntryPoint()
                    self.logger.info('Hooked function `%s` at %s with implementation lib/%s/%s' % (thunk.getName(), str(address), library.name, implementation.__name__))
                    install(implementation, address)
                            
        
    def initFunctionParameters(self, bytesValueBuffer=""):
        '''Collect the hex-encoded input value of every function parameter.

        Values are consumed from `bytesValueBuffer`, two characters per byte.
        When the buffer holds fewer characters than a parameter needs, the
        user is prompted for a new buffer. Short values are right-padded with
        "00". Byte positions that are not two hex digits are recorded as
        wildcards, and each wildcard gets the initial substitution "00".

        :param bytesValueBuffer: Concatenated hex values of the parameters.
        '''
        self.input_wildcards = []
        self.fnParametersAllBytesValue = ""

        parameters = [self.symbolMap.getParam(i) for i in range(self.symbolMap.getNumParams())]
        for parameter in parameters:
            psize = self.parameterStorageSize(parameter)
            hex_length = psize*2
            if len(bytesValueBuffer) < hex_length:
                bytesValueBuffer = self.plugin.askString('Setting Parameters for `{}` (size: {})'.format(parameter.name, psize), 'byte values')

            # This parameter's share of the buffer, zero-padded to full size.
            bytesValue = (bytesValueBuffer[:hex_length] + "00"*psize)[:hex_length]
            assert(len(bytesValue) == hex_length)

            for start in range(0, len(bytesValue), 2):
                is_hex_byte = (bytesValue[start] in string.hexdigits and
                               bytesValue[start+1] in string.hexdigits)
                if not is_hex_byte:
                    self.input_wildcards.append(bytesValue[start:start+2])

            self.parameterMap[parameter.name] = bytesValue
            self.fnParametersAllBytesValue += bytesValue

            bytesValueBuffer = bytesValueBuffer[hex_length:]

        if self.input_wildcards:
            self.logger.info("Found %d wildcards: %s" % (len(self.input_wildcards), self.input_wildcards))
            self.logger.info("The next batch of cmds will be executed in fuzzing mode")

        for wildcard in self.input_wildcards:
            self.byte_substitution[wildcard] = "00"

        self.emulator_state = EmulatorState.READY

    # @staticmethod
    def parameterStorageSize(self, parameter):
        """Return the summed getSize() of all varnodes in `parameter`'s storage."""
        varnodes = parameter.getStorage().getVarnodes()
        return sum(varnode.getSize() for varnode in varnodes)

    def getAddress(self, offset):
        """Return the address for `offset` in the program's default address space."""
        default_space = self.program.getAddressFactory().getDefaultAddressSpace()
        return default_space.getAddress(offset)

    def getStackAddress(self, offset):
        """
        Return the address `offset` bytes from the initial stack pointer.

        The address relative to the current stack-pointer register value is
        computed as well, but only for the debug log.
        """
        sp_register = self.emulatorHelper.getStackPointerRegister()
        live_address = self.getAddress(self.emulatorHelper.readRegister(sp_register) + offset)
        initial_address = self.getAddress(self.stackPointer + offset)
        self.logger.debug('Stack address at {} or {}'.format(live_address, initial_address))
        return initial_address

    def writeStackValue(self, offset, size, value):
        """
        Write the integer `value` as `size` bytes at stack offset `offset`.

        The bytes produced by long_to_bytes() are reversed for little-endian
        targets before they are written.
        """
        # `self` was missing from the signature, so any call through an
        # instance failed before reaching the body.
        bytesValue = long_to_bytes(value, size)
        if not self.emulatorHelper.getLanguage().isBigEndian():
            bytesValue = bytesValue[::-1]
        self.emulatorHelper.writeMemory(self.getStackAddress(offset), bytesValue)

    def applyByteSubstitution(self, bytesValue):
        """
        Replace every substitution key in the hex string and decode the result.

        :param bytesValue: Hex string, possibly containing wildcard pairs.
        :returns: The string obtained by hex-decoding the substituted text.
        """
        substituted = bytesValue
        for wildcard, replacement in self.byte_substitution.items():
            substituted = substituted.replace(wildcard, replacement)
        return substituted.decode('hex')

    def start(self, byte_substitution=None):
        '''Write the function inputs into emulated memory and run to the entry point.

        Each parameter value has the byte substitution applied and is then
        split across the parameter's varnodes. Stack varnodes are written
        relative to the emulator's stack, all others at their own address.

        :param byte_substitution: Replacement for the stored substitution map;
                                  None keeps the current one.
        '''
        assert(self.emulator_state == EmulatorState.READY)
        if byte_substitution is not None:
            self.byte_substitution = byte_substitution

        self.logger.info('Started with byte_sub: %r' % self.byte_substitution)

        parameters = [self.symbolMap.getParam(i) for i in range(self.symbolMap.getNumParams())]
        for parameter in parameters:
            payload = self.applyByteSubstitution(self.parameterMap[parameter.name])
            cursor = 0
            for varnode in parameter.getStorage().getVarnodes():
                size = varnode.getSize()
                chunk = payload[cursor:cursor+size]
                destination = varnode.getAddress()
                if destination.isStackAddress():
                    destination = self.getStackAddress(destination.getOffset())
                self.emulatorHelper.writeMemory(destination, chunk)
                cursor += size

        entry_address = self.function.getEntryPoint()
        self.emulatorHelper.setBreakpoint(entry_address)
        self.emulatorHelper.run(entry_address, self.entrypoint, self.monitor)

        self.emulator_state = EmulatorState.EXECUTING

    def executeCmds(self, cmds):
        """
        Run a batch of debugger commands separated by ", ".

        Every command is split on whitespace; its first word selects the
        handler, which receives the whole word list. A truthy handler result
        is stored in `self.last_result`. An unknown command prints the help
        and aborts the rest of the batch; empty commands are skipped.

        :param cmds: Command string, e.g. "s, s, p".
        """
        assert(self.emulator_state == EmulatorState.EXECUTING)
        for raw_cmd in cmds.strip().split(', '):
            cmd = raw_cmd.strip().split()
            if not cmd:
                # Empty input or a doubled separator: nothing to dispatch.
                # Indexing cmd[0] here used to raise IndexError.
                continue

            handler = self.cmd_handlers.get(cmd[0])
            if handler is None:
                self.logger.error("Unknown command %s (%r)" % (cmd[0], cmd))
                self.cmdHelp(cmd)
                break

            res = handler(cmd)
            if res:
                self.last_result = res
            self.updateUI()
        # self.printState()
        self.logger.info('Stopping execution for {} at {:x} with error {}'.format(self.emulatorHelper.getEmulateExecutionState(), self.emulatorHelper.readRegister(self.emulatorHelper.getPCRegister()), self.emulatorHelper.getLastError()))
    
    def printState(self):
        """
        Log the emulated state: program symbols, local variables and registers.

        Only symbols and variables whose name occurs in the decompiled text
        are reported. Values are shown with repr() when isPrintable() accepts
        them and hex-encoded otherwise.
        """
        helper = self.emulatorHelper

        def render(raw):
            return repr(raw) if isPrintable(raw) else raw.encode('hex')

        # Program symbols that appear in the decompiled text.
        for symbol in self.program.getSymbolTable().getAllSymbols(True):
            symbolObject = symbol.getObject()
            try:
                dataType = symbolObject.getDataType()
                name = symbol.getName()
                if name not in self.decompiled or not symbol.getAddress():
                    continue
                self.logger.debug('Found symbol name={} type={} location={}'.format(name, dataType, symbol.getAddress()))
                raw_bytes = helper.readMemory(symbol.getAddress(), dataType.getLength())
                self.logger.info('Variable {} has value `{}`'.format(name, render(raw_bytes.tostring())))
            except AttributeError as e:
                self.logger.debug(str(e))
            except Exception as e:
                self.logger.error(str(e))

        # Local symbols: stack varnodes are always read, other varnodes only
        # when their address is in the tracked write set.
        writeSet = helper.getTrackedMemoryWriteSet()
        for local in self.highFunction.getLocalSymbolMap().getSymbols():
            if local.name not in self.decompiled:
                continue
            storage = local.getStorage()
            collected = bytearray(0)
            for varnode in storage.getVarnodes():
                location = varnode.getAddress()
                if location.isStackAddress():
                    collected.extend(helper.readMemory(self.getStackAddress(location.getOffset()), varnode.getSize()))
                elif writeSet.contains(location):
                    collected.extend(helper.readMemory(location, varnode.getSize()))
            self.logger.info('Variable `{}` @ `{}` has value `{}`'.format(local.name, storage, render(str(collected))))

        # Base registers, excluding processor-context registers.
        for register in helper.getLanguage().getRegisters():
            if not register.isBaseRegister() or register.isProcessorContext():
                continue
            self.logger.debug(str(register))
            self.logger.debug(str(helper.readRegister(register)))

        self.logger.debug(str(helper))
        self.logger.debug(str(helper.getLanguage()))
        self.logger.debug(str(helper.getLanguage().getRegisters()))
        register_dump = ['{} = {}'.format(register, helper.readRegister(register)) for register in helper.getLanguage().getRegisters() if register.isBaseRegister() and not register.isProcessorContext()]
        self.logger.info(str(register_dump))

        self.logger.info('Stopping execution at {:x}'.format(helper.readRegister(helper.getPCRegister())))
        self.logger.debug('Logged writes at {}'.format(helper.getTrackedMemoryWriteSet()))
    
    def readMemory(self, from_, size):
        """
        Read `size` bytes of emulated memory at `from_`, log and return them.

        :param from_: Value accepted by getAddress().
        :param size: Number of bytes to read.
        :returns: The bytes read, as a string.
        """
        source = self.getAddress(from_)
        collected = bytearray(0)
        collected.extend(self.emulatorHelper.readMemory(source, size))
        stringValue = str(collected)
        self.logger.info('Reading from {} (size: {}): {}\n\thex={}'.format(from_, size, repr(stringValue), stringValue.encode("hex")))
        return stringValue


    def readPointer(self, address):
        """
        Read a pointer-sized integer from emulated memory.

        The pointer width is the bit length of the program counter divided by
        8; the bytes are interpreted in the byte order of the program's
        language.

        :param address: Address to read from.
        :returns: The pointer value as an integer.
        """
        language = self.program.getLanguage()
        pointer_size = language.getProgramCounter().getBitLength()//8
        self.logger.debug('reading %d from address %s' % (pointer_size, str(address)))
        packed = bytearray(0)
        packed.extend(self.emulatorHelper.readMemory(address, pointer_size))
        self.logger.debug('reading `%s` from address' % repr(str(packed)))
        # Always convert to str: previously this only happened on the
        # little-endian path, so big-endian targets reached `.encode('hex')`
        # with a bytearray, which has no encode().
        raw = str(packed)
        if not language.isBigEndian():
            raw = raw[::-1]
        self.logger.debug('got pointer at `%s`' % repr(raw))
        return int(raw.encode('hex'), 16)
    def writeMemory(self, from_, bytesValue):
        """
        Write a hex-encoded value to emulated memory at `from_`.

        :param from_: Value accepted by getAddress().
        :param bytesValue: Hex string; the byte substitution is applied and
                           the result decoded before writing.
        """
        # applyByteSubstitution is a method; the bare name used to raise
        # NameError.
        bytesValue = self.applyByteSubstitution(bytesValue)
        self.emulatorHelper.writeMemory(self.getAddress(from_), bytesValue)
    
    def updateUI(self):
        """Ask the plugin to sync its view to the emulator's execution address."""
        execution_address = self.emulatorHelper.getExecutionAddress()
        self.plugin.syncView(execution_address)

    def initCmdHandlers(self):
        """Build the table mapping command words to their handler methods."""
        bindings = [
            ('s', self.cmdStep),
            ('c', self.cmdContinue),
            ('n', self.cmdNext),
            ('b', self.cmdBreakpointAdd),
            ('d', self.cmdBreakpointRemove),
            ('x', self.cmdSleep),
            ('q', self.cmdQuit),
            ('r', self.cmdReadMem),
            ('w', self.cmdWriteMem),
            ('p', self.cmdPrintState),
            ('h', self.cmdHelp),
            ('l', self.cmdLogHistory),
            ('e', self.cmdEval),
            ('hook', self.cmdHook),
            ('list-hooks', self.cmdListHook),
        ]
        self.cmd_handlers = dict(bindings)

    
    @history
    def cmdHook(self, cmd):
        '''hook address module.function - replace a function with a python implementation
        e.g. hook 0x40000 libc6.puts
        '''
        address = self.getAddress(int(cmd[1], 16))
        # Split on the first dot only: the module name comes first.
        library_name, function_name = cmd[2].split('.', 1)
        thunkedFunction = self.program.getFunctionManager().getFunctionContaining(address)
        for library in lib.exports:
            # The old test compared `library_name` with itself, so every
            # library was searched regardless of the requested module.
            if library.name != library_name:
                continue
            self.logger.debug('Found library `%s`' % (library.name))
            for function in library.exports:
                self.logger.debug('Found function `%s`' % (function.__name__))
                if function_name == function.__name__:
                    self.logger.info('Hooked function `%s` at %s with implementation lib/%s/%s' % (thunkedFunction.getName(), str(address), library.name, function.__name__))
                    callback = DucktapeBreakCallback(function(self.program, self, thunkedFunction, self.monitor), lambda x: True)
                    # callback.addressCallback = function(self.program, self, self.program.getFunctionManager().getFunctionAt(address), self.monitor)
                    self.emulatorHelper.emulator.getBreakTable().registerAddressCallback(address, callback)
                    return
        self.logger.error('No implementation `%s` found, see `list-hooks`' % cmd[2])
    @history
    def cmdListHook(self, cmd):
        '''List available hooks
        '''
        # One info line per implementation: "<library>.<function> - <doc>".
        for library in lib.exports:
            self.logger.debug('Found library `%s`' % (library.name))
            for implementation in library.exports:
                implementation_name = implementation.__name__
                self.logger.debug('Found function `%s`' % (implementation_name))
                self.logger.info('%s.%s - %s' % (library.name, implementation_name, implementation.__doc__))

    @history
    def cmdStep(self, cmd):
        '''step'''
        # Advance the emulation by one step.
        helper = self.emulatorHelper
        helper.step(self.monitor)
    def run(self, monitor):
        """
        Step the emulator until its halt flag is set.

        After every step the execution address is recorded, so that
        `self.lastAddresses` holds the two most recent distinct addresses,
        newest first.
        """
        emulator = self.emulatorHelper.emulator
        emulator.setHalt(False)
        while not self.emulatorHelper.emulator.getHalt():
            self.emulatorHelper.step(monitor)
            currentAddress = self.emulatorHelper.getExecutionAddress()
            if self.lastAddresses and self.lastAddresses[0] == currentAddress:
                # Still at the same address: nothing new to record.
                continue
            self.lastAddresses = [currentAddress] + self.lastAddresses[:1]

    @history
    def cmdContinue(self, cmd):
        '''continue'''
        # Resume until the emulator halts.
        monitor = self.monitor
        self.run(monitor)
    
    @history
    def cmdNext(self, cmd):
        '''step over/next'''
        next_instruction = self.flatapi.getInstructionAfter(self.emulatorHelper.getExecutionAddress())
        if next_instruction is None:
            # No instruction follows, so there is nowhere to stop.
            self.logger.error('No instruction after the current one, cannot step over')
            return
        address = next_instruction.getAddress()
        self.emulatorHelper.setBreakpoint(address)
        try:
            self.run(self.monitor)
        finally:
            # Remove the temporary breakpoint even when the run raises.
            self.emulatorHelper.clearBreakpoint(address)
        
    @history
    def cmdBreakpointAdd(self, cmd):
        '''add breakpoint (`hex_address`)'''
        # cmd[1] is the breakpoint address in hexadecimal.
        offset = int(cmd[1], 16)
        self.emulatorHelper.setBreakpoint(self.getAddress(offset))

    @history
    def cmdBreakpointRemove(self, cmd):
        '''remove breakpoint (`hex_address`)'''
        # cmd[1] is the breakpoint address in hexadecimal.
        offset = int(cmd[1], 16)
        self.emulatorHelper.clearBreakpoint(self.getAddress(offset))

    @history
    def cmdSleep(self, cmd):
        '''sleep (`time(=5)`)'''
        # Instead of sleeping, poll the monitor a fixed number of times; the
        # result of each poll is ignored.
        for _ in range(10000):
            self.monitor.isCancelled()
        # time.sleep(5 if len(cmd) == 1 else int(cmd[1]))
    
    @history
    def cmdQuit(self, cmd):
        '''quit'''
        # Log the final state and refresh the view before marking the
        # emulator as finished.
        self.printState()
        self.updateUI()
        self.emulator_state = EmulatorState.DONE

    @history
    def cmdReadMem(self, cmd):
        '''read memory addr (either `hex_from:hex_to` or `hex_from size`)'''
        if len(cmd) == 3:
            # `hex_from size`: the size is hex when it contains "0x",
            # decimal otherwise.
            from_ = cmd[1]
            size = int(cmd[2], 16 if "0x" in cmd[2].lower() else 10)
        else:
            # `hex_from:hex_to`: the size is the distance between the bounds.
            start, end = [int(bound, 16) for bound in cmd[1].split(":")]
            size = end - start
            # '%x' instead of hex(): hex() appends an 'L' to Python 2 longs,
            # which corrupted the address string for large addresses.
            from_ = '%x' % start
        return self.readMemory(from_.replace("0x",""), size)
    
    @history
    def cmdWriteMem(self, cmd):
        '''write memory addr (`hex_addr hex_bytes`)'''
        # cmd[1] is the destination address, cmd[2] the hex-encoded payload.
        destination, payload = cmd[1], cmd[2]
        self.writeMemory(destination, payload)
    
    @history
    def cmdPrintState(self, cmd):
        '''print state'''
        # Thin command wrapper around printState(); `cmd` is not used.
        self.printState()

    @history
    def cmdEval(self, cmd):
        '''executes your command'''
        # Everything after the command word is joined with single spaces and
        # executed as Python code in this method's scope, so `self` and `cmd`
        # are visible to it.
        # NOTE(review): this runs arbitrary code typed at the command prompt;
        # only expose it to trusted operators.
        exec(' '.join(cmd[1:]))

    def cmdHelp(self, cmd):
        '''help'''
        # One line per registered command: its word and its handler's
        # docstring.
        self.logger.info("Commands:")
        for word, handler in self.cmd_handlers.items():
            description = handler.__doc__
            self.logger.info("\t%s: %s" % (word, description))

    def cmdLogHistory(self, cmd):
        '''prints a serialized version of this debugging session'''
        # The info line is the history joined with ", ".
        self.logger.debug(self.history)
        serialized = ', '.join(self.history)
        self.logger.info("`%s`" % (serialized))
Esempio n. 13
0
def main():
    """
    Find every "DBT:" TraceEntry in the current program, type it and label it.

    The program is searched for the byte signature "DBT:". At each hit the
    TraceEntry structure is applied, the strings referenced by component 4
    (message) and component 6 (file) are read, creating them when missing,
    and a "TraceEntry::<file>::<message>" label of at most 60 characters is
    placed on the entry.
    """
    fapi = FlatProgramAPI(currentProgram)

    te_list = fapi.getDataTypes("TraceEntry")

    if len(te_list) == 0:
        create_trace_entry()
        te_list = fapi.getDataTypes("TraceEntry")

        if len(te_list) == 0:
            print("ERROR: failed to create TraceEntry data type")
            return

    te = te_list[0]

    caddr = fapi.toAddr(0)

    monitor.setCancelEnabled(True)
    monitor.setIndeterminate(True)

    # Collect the signature hits in batches of up to 1000, resuming one byte
    # past the last hit of the previous batch.
    trace_entry_addrs = []
    while caddr is not None and not monitor.isCancelled():
        caddrs = fapi.findBytes(caddr, "DBT:", 1000)

        if not caddrs:
            break

        trace_entry_addrs += caddrs

        caddr = caddrs[-1].add(1)
        monitor.setMessage("Found %d TraceEntries" % len(trace_entry_addrs))

        if DEBUG_FIND and len(trace_entry_addrs) > DEBUG_FIND_MAX_ENTRIES:
            break

    print("Found %d TraceEntry structures" % len(trace_entry_addrs))

    # Uncomment if you just want to dump the entries
    #dump_trace_entries(fapi, "trace-entries.txt", trace_entry_addrs); return

    maximum = len(trace_entry_addrs)
    monitor.setIndeterminate(False)
    monitor.initialize(maximum)
    monitor.setCancelEnabled(True)

    monitor.setProgress(0)
    monitor.setMessage("Typing TraceEntries...")

    def string_data_at(addr, label, position):
        # Return the string data at `addr`, creating it when it is missing
        # or not a unicode string; None when it cannot be created.
        data = fapi.getDataAt(addr)
        if data is not None and isinstance(data.getValue(), unicode):
            return data
        data = force_create_string(fapi, addr, 1024)
        if data is not None and SHOW_OUTPUT:
            print("[%d/%d] Created missing %s string @ %s" %
                  (position, maximum, label, addr))
        return data

    for cur, caddr in enumerate(trace_entry_addrs):
        if monitor.isCancelled():
            break

        monitor.incrementProgress(1)

        if DEBUG_RETYPE and cur < DEBUG_RETYPE_SKIP:
            continue

        try:
            # The end address is inclusive: stop at the structure's last
            # byte so the byte following it is left alone.
            fapi.clearListing(caddr, caddr.add(te.getLength() - 1))

            trace_entry = fapi.createData(caddr, te)
            message_addr = trace_entry.getComponent(4).getValue()
            file_addr = trace_entry.getComponent(6).getValue()

            message_data = string_data_at(message_addr, "message", cur + 1)
            if message_data is None:
                continue

            file_data = string_data_at(file_addr, "file", cur + 1)
            if file_data is None:
                continue

            message_str = message_data.getValue()
            file_str = file_data.getValue()

            message_str_fixed = fixup_format_string(message_str)

            #print("[%s] %s" % (file_str, message_str_fixed))
            base_file = os.path.basename(file_str).split(".")[0]
            symbol_name = "TraceEntry::%s::%s" % (base_file, message_str_fixed)
            # limit the length
            symbol_name_assign = symbol_name[:60]
            fapi.createLabel(caddr, symbol_name_assign, True,
                             SourceType.USER_DEFINED)

            if SHOW_OUTPUT:
                print("[%d/%d] [%s] %s" %
                      (cur + 1, maximum, caddr, symbol_name))
        except ghidra.program.model.mem.MemoryAccessException:
            # this happens with the false positive match of "DBT:String too long"
            print("[%d/%d] [%s] Invalid TraceEntry signature!" %
                  (cur + 1, maximum, caddr))
        except ghidra.program.model.util.CodeUnitInsertionException:
            print("[%d/%d] Something else already at %s" %
                  (cur + 1, maximum, caddr))
            # Uncomment this raise, if you want to catch all strangeness
            # raise

    print('Done!')
Esempio n. 14
0
class MipsRop(object):
    def __init__(self, program):
        """
        Scan `program` for controllable calls as soon as the object is built.

        :param program: Program to search for ROP gadgets.
        """
        self._currentProgram = program
        self._flat_api = FlatProgramAPI(program)

        # Both lists are filled by _find_controllable_calls() below.
        self.controllable_calls = []
        self.controllable_terminating_calls = []
        self._find_controllable_calls()

    def find_instructions(self, instructions, preserve_register=None,
                          controllable_calls=True, terminating_calls=True,
                          overwrite_register=None):
        """
        Collect the gadgets that contain the requested instructions.

        :param instructions: List of instructions to search for.
        :type instructions: list(MipsInstruction)

        :param preserve_register: Registers to preserve.
        :type preserve_register: str

        :param controllable_calls: Search within controllable jumps.
        :type controllable_calls: bool

        :param terminating_calls: Search within controllable function epilogues.
        :type terminating_calls: bool

        :param overwrite_register: Register to ensure is overwritten.
        :type overwrite_register: str

        :returns: List of rop gadgets that contain the provided instructions.
        :rtype: list(RopGadgets)
        """
        gadgets = RopGadgets()

        # Regular calls come first, then the terminating ones.
        candidates = []
        if controllable_calls:
            candidates += self.controllable_calls
        if terminating_calls:
            candidates += self.controllable_terminating_calls

        for candidate in candidates:
            match = self._find_instruction(
                candidate, instructions, preserve_register, overwrite_register)
            if not match:
                continue
            gadgets.append(RopGadget(match, candidate))

        return gadgets

    def find_doubles(self):
        """
        Pair up controllable jumps that sit close together.

        Two jumps form a double gadget when the second one's control
        instruction lies 1 to 100 address units after the first one's call,
        both calls are in the same function, their source registers differ
        and _contains_bad_calls() does not reject the pair.

        :returns: List of double jump gadgets.
        :rtype: DoubleGadgets
        """
        controllable = self.controllable_calls + \
            self.controllable_terminating_calls
        total = len(controllable)

        gadgets = DoubleGadgets()
        for first_index, call in enumerate(controllable):
            for second_index in range(first_index + 1, total):
                second_call = controllable[second_index]
                second_ctrl_addr = second_call.control_instruction.getAddress()
                distance = second_ctrl_addr.subtract(call.call.getAddress())

                # Search for a distance of no more than 25 instructions.
                if not 0 < distance <= 100:
                    continue

                # Jumps located in different functions are not paired.
                second_function = self._flat_api.getFunctionContaining(
                    second_call.call.getAddress())
                first_function = self._flat_api.getFunctionContaining(
                    call.call.getAddress())
                if second_function != first_function:
                    continue

                if call.get_source_register() == \
                        second_call.get_source_register():
                    continue

                if self._contains_bad_calls(call, second_call):
                    continue

                gadgets.append(DoubleGadget(call, second_call))

        return gadgets

    def summary(self):
        """
        Print a summary of the gadgets marked by bookmarks starting with 'rop'.

        The case of 'rop' is not important. Bookmark comments must be unique
        (ignoring case); when a duplicate is found it is reported and nothing
        else is printed. A comment ending in '_d' marks a double gadget.
        Bookmarks for which no controllable jump is found are skipped.
        """
        bookmark_manager = self._currentProgram.getBookmarkManager()

        # Lower-cased comment -> bookmark, which also detects duplicates.
        rop_bookmarks = {}
        for bookmark in bookmark_manager.getBookmarksIterator():
            comment = (bookmark.getComment() or '').lower()
            if not comment.startswith('rop'):
                continue
            if comment in rop_bookmarks:
                print('Duplicate bookmark found: {} at {} and {}'.format(
                    comment, rop_bookmarks[comment].getAddress(),
                    bookmark.getAddress()))
                return
            rop_bookmarks[comment] = bookmark

        saved_bookmarks = sorted(rop_bookmarks.values(),
                                 key=lambda x: x.getComment().lower())

        rop_gadgets = RopGadgets()

        # Go through each bookmark, find the closest controllable jump, and
        # create a gadget.
        for bookmark in saved_bookmarks:
            address = bookmark.getAddress()
            closest_jmp = self._find_closest_controllable_jump(address)
            if not closest_jmp:
                # Nothing to build a gadget from. The '_d' branch used to
                # dereference closest_jmp before checking it.
                continue

            if bookmark.getComment().lower().endswith('_d'):
                next_closest = self._find_closest_controllable_jump(
                    closest_jmp.call.getAddress())
                if next_closest:
                    # Hack to change the "control" instruction in case the
                    # bookmark was placed at a different location.
                    closest_jmp.control_instruction = \
                        self._flat_api.getInstructionAt(address)

                    rop_gadgets.append(DoubleGadget(closest_jmp, next_closest,
                                                    bookmark.getComment()))
            else:
                curr_ins = self._flat_api.getInstructionAt(address)
                rop_gadgets.append(RopGadget(curr_ins, closest_jmp,
                                             bookmark.getComment()))
        rop_gadgets.print_summary()

    def _find_closest_controllable_jump(self, address):
        """
        Find closest controllable jump to the address provided.

        Only jumps whose call is in the same function as the address, whose
        control instruction is not before the address, and whose call is not
        at the address itself are considered.

        :param address: Address to find closest jump to.
        :type address: ghidra.program.model.address.Address

        :returns: Closest controllable jump, if it exists.
        :rtype: ControllableCall or None
        """
        controllable = self.controllable_calls + \
            self.controllable_terminating_calls

        function = self._flat_api.getFunctionContaining(address)

        closest = None

        # NOTE(review): the [1:] slice means the first controllable call is
        # never considered -- confirm whether that is intended.
        for jump in controllable[1:]:
            jump_function = self._flat_api.getFunctionContaining(
                jump.call.getAddress())
            if function != jump_function:
                continue

            # Skip jumps whose control instruction is before the address.
            if address > jump.control_instruction.getAddress():
                continue

            # If the address is a jump do not consider it for the closest jump.
            if jump.call.getAddress() == address:
                continue

            # Take the first candidate, or any later one whose span from
            # control instruction to call contains the address.
            if not closest or \
                    jump.control_instruction.getAddress() <= \
                    address <= jump.call.getAddress():
                closest = jump
            else:
                # NOTE(review): every candidate reaching this point has its
                # control instruction at or after `address`, so this
                # comparison looks like it can never be true -- confirm the
                # intended distance test.
                control_addr = jump.control_instruction.getAddress()
                closest_distances = closest.control_instruction.getAddress()
                if control_addr.subtract(closest_distances) > \
                        control_addr.subtract(address):
                    closest = jump
        return closest

    def _find_controllable_calls(self):
        """
        Find calls that can be controlled through saved registers.
        """
        program_base = self._currentProgram.getImageBase()

        code_manager = self._currentProgram.getCodeManager()
        instructions = code_manager.getInstructions(program_base, True)

        # Loop through each instruction in the current program.
        for ins in instructions:
            flow_type = ins.getFlowType()

            # jalr t9 and some jr t9 are isCall()
            # jr ra is isTerminal()
            # some jr t9 are isJump() && isComputed().
            if flow_type.isCall() or flow_type.isTerminal() or \
                    (flow_type.isJump() and flow_type.isComputed()):
                current_instruction = self._flat_api.getInstructionAt(
                    ins.getAddress())
                controllable = self._find_controllable_call(
                    current_instruction)

                # Sort the controllable jump by type. Makes finding indirect
                # function calls easier.
                if controllable:
                    if flow_type.isCall() and not flow_type.isTerminal():
                        self.controllable_calls.append(controllable)
                    elif flow_type.isTerminal() or \
                            (flow_type.isJump() and flow_type.isComputed()):
                        self.controllable_terminating_calls.append(
                            controllable)

    def _find_controllable_call(self, call_instruction):
        """
        Search for how the jump register is set. If it comes from a potentially
        controllable register then return it.

        The walk goes backwards from the call until it finds the first
        instruction whose first operand is the jump register. That
        instruction decides the result: it is either one of the accepted
        patterns (a move into t9 from an s/v/a register, or a load of ra) or
        the call is not controllable.

        :param call_instruction: Instruction that contains a call.
        :type call_instruction: ghidra.program.model.listing.Instruction

        :returns: Controllable call object if controllable, None if not.
        :rtype: ControllableCall or None
        """
        t9_move = MipsInstruction('.*move', 't9', '[sva][012345678]')
        ra_load = MipsInstruction('.*lw', 'ra')

        # The first operand of the call names the register jumped through.
        # Without one there is nothing to trace back.
        call_ops = call_instruction.getOpObjects(0)
        if not len(call_ops):
            return None
        call_from = call_ops[0]

        # No need to check the delay slot so start working back up.
        previous_ins = self._get_previous_instruction(call_instruction)

        while previous_ins:
            # NOPs are handled weirdly, they have no "flow" so just skip it.
            # Re-enter the loop so a missing predecessor ends the search
            # instead of being dereferenced.
            if 'nop' in str(previous_ins):
                previous_ins = previous_ins.getPrevious()
                continue

            first_op = previous_ins.getOpObjects(0)
            if len(first_op):
                dest_reg = first_op[0]
                if str(dest_reg) == str(call_from):
                    if instruction_matches(previous_ins,
                                           [t9_move, ra_load]):
                        return ControllableCall(call_instruction, previous_ins)
                    return None

            previous_ins = self._get_previous_instruction(previous_ins)

        return None

    def _get_previous_instruction(self, instruction):
        """
        Get the previous instruction. Check the "flow" first, if not found
        just return the previous memory instruction.

        :param instruction: Instruction to retrieve previous instruction from.
        :type instruction: ghidra.program.model.listing.Instruction
        """
        fall_from = instruction.getFallFrom()
        if fall_from is None:
            previous_ins = instruction.getPrevious()
        else:
            previous_ins = self._flat_api.getInstructionAt(fall_from)

        return previous_ins

    def _find_instruction(self, controllable_call, search_instructions,
                          preserve_reg=None, overwrite_reg=None):
        """
        Search for an instruction within a controllable call.

        The instruction following the call (its delay slot) is checked
        first, then the instructions preceding the call are walked backwards
        until a match is found or the search has to be abandoned.

        :param controllable_call: Controllable call to search within.
        :type controllable_call: ControllableCall

        :param search_instructions: Instruction list to search for.
        :type search_instructions: list(MipsInstruction)

        :param preserve_reg: Register to preserve, if overwritten the
                             instruction will not be returned.
        :type preserve_reg: str

        :param overwrite_reg: Enforce a register was overwritten.
        :type overwrite_reg: str

        :returns: The matching instruction if found, None otherwise.
        :rtype: ghidra.program.model.listing.Instruction
        """
        overwritten = False

        delay_slot = controllable_call.call.getNext()
        if instruction_matches(delay_slot, search_instructions):
            return delay_slot

        # Name of the register the control instruction reads from, taken the
        # same way _contains_bad_calls does. register_overwritten is given
        # register names everywhere else, so pass a name here as well rather
        # than the instruction object.
        control_ops = controllable_call.control_instruction.getOpObjects(1)
        control_reg = str(control_ops[-1]) if len(control_ops) else None

        previous_ins = self._get_previous_instruction(controllable_call.call)

        while previous_ins:
            # Skip NOPs and re-enter the loop so that running out of
            # instructions ends the search instead of raising.
            if 'nop' in str(previous_ins):
                previous_ins = previous_ins.getPrevious()
                continue

            if instruction_matches(previous_ins, search_instructions):
                # A match only counts if the required overwrite was seen
                # between it and the call.
                if overwrite_reg and not overwritten:
                    return None
                return previous_ins

            if preserve_reg and \
                    register_overwritten(previous_ins, preserve_reg):
                return None

            if overwrite_reg and register_overwritten(previous_ins,
                                                      overwrite_reg):
                overwritten = True

            # TODO: Need to see if we passed the point of caring.
            if control_reg and register_overwritten(previous_ins,
                                                    control_reg):
                return None

            if is_jump(previous_ins):
                return check_delay_slot(previous_ins, search_instructions)

            previous_ins = self._get_previous_instruction(previous_ins)

        return None

    def _contains_bad_calls(self, first, second):
        """
        Search for bad calls between two controllable jumps.

        Walks backwards from the second jump's control instruction down to
        (but not including) the first jump's call. Any jump or branch in
        between, or any write to the register the second control instruction
        reads from, counts as bad.

        :param first: Controllable call that comes first in memory.
        :type first: ControllableCall

        :param second: Controllable call that comes second in memory.
        :type second: ControllableCall

        :returns: True if bad calls are found, False otherwise.
        :rtype: bool
        """
        jump = MipsInstruction('j.*')
        branch = MipsInstruction('b.*')

        preserve_reg = str(second.control_instruction.getOpObjects(1)[-1])
        end_ins = first.call

        previous_ins = self._get_previous_instruction(
            second.control_instruction)

        while previous_ins is not None and \
                previous_ins.getAddress() > end_ins.getAddress():
            # Skip NOPs, then re-check the loop bounds: the skip may run out
            # of instructions or land on the first call itself, which must
            # not be inspected (it would match the jump pattern).
            if 'nop' in str(previous_ins):
                previous_ins = previous_ins.getPrevious()
                continue

            if instruction_matches(previous_ins, [jump, branch]):
                return True

            if register_overwritten(previous_ins, preserve_reg):
                return True

            previous_ins = self._get_previous_instruction(previous_ins)

        return False
Esempio n. 15
0
class Rizzo(object):
    def __init__(self, program):
        """
        Collect the program's strings and generate its signatures.

        :param program: Program to generate signatures for.
        :type program: ghidra.program.model.listing.Program
        """
        self._program = program
        self._flat_api = FlatProgramAPI(program)
        self._memory_map = program.getMemory()
        self._simple_blk = BasicBlockModel(program)
        self._monitor = self._flat_api.getMonitor()
        self._function_manager = program.getFunctionManager()
        self._address_factory = program.getAddressFactory()

        self.signatures = None
        self._strings = {}
        self._find_strings()

        # Time only the signature generation itself.
        started = time.time()
        self._signatures = self._generate()
        elapsed = time.time() - started

        report = ('Generated %d formal signatures and %d fuzzy signatures '
                  'for %d functions in %.2f seconds.')
        print(report % (len(self._signatures.formal),
                        len(self._signatures.fuzzy),
                        len(self._signatures.functions),
                        elapsed))

    def save(self, signature_file):
        """
        Pickle the generated signatures into the supplied file.

        :param signature_file: Full path to save signatures.
        :type signature_file: str
        """
        destination = signature_file
        print('Saving signature to %s...' % destination)
        with open(destination, 'wb') as rizz_file:
            pickle.dump(self._signatures, rizz_file)
        print('done.')

    def load(self, signature_file):
        """
        Load Rizzo signatures from a file.

        NOTE: the file is read with pickle, which can execute arbitrary code
        while loading. Only load signature files from a trusted source.

        :param signature_file: Full path to load signatures from.
        :type signature_file: str

        :returns: Loaded signatures
        :rtype: RizzoSignatures

        :raises Exception: If the signature file does not exist.
        :raises SystemExit: With code 1 if the file cannot be unpickled.
        """
        if not os.path.exists(signature_file):
            raise Exception('Signature file %s does not exist' %
                            signature_file)

        print('Loading signatures from %s...' % signature_file)
        with open(signature_file, 'rb') as rizz_file:
            try:
                signatures = pickle.load(rizz_file)
            except Exception:
                # Unpickling failures surface as many different exception
                # types, but KeyboardInterrupt and SystemExit must not be
                # mistaken for a bad file.
                print('This does not appear to be a Rizzo signature file.')
                # Same effect as exit(1) without depending on the
                # site-provided helper being installed.
                raise SystemExit(1)
        print('done.')
        return signatures

    def apply(self, signatures):
        """
        Apply signatures to the current program.

        Each matched function is renamed first, then the functions called
        from uniquely matching blocks are renamed as well. A name is applied
        at most once per call.

        :param signatures: Signatures to apply to current program.
        :type signatures: RizzoSignatures
        """
        rename_count = 0
        signature_matches = self._find_match(signatures)
        # Names already handled; a set keeps the membership tests cheap.
        renamed = set()

        for matches in signature_matches:
            for curr_func, new_func in matches.iteritems():
                # Format the address explicitly. hex() only appends a
                # trailing 'L' for Python 2 longs, so slicing off the last
                # character drops a real digit when the value is a plain int.
                curr_addr = self._address_factory.getAddress(
                    '0x%x' % curr_func.address)
                function = self._flat_api.getFunctionAt(curr_addr)
                if function and new_func.name not in renamed:
                    renamed.add(new_func.name)
                    if self._rename_functions(function, new_func.name):
                        rename_count += 1

                # Pair up blocks that match exactly once. A block that
                # matches more than once is ambiguous and is dropped.
                duplicates = []
                block_match = {}
                for block in new_func.blocks:
                    new_block = RizzoBlockDescriptor(block)
                    for curr_block in curr_func.blocks:
                        curr_block = RizzoBlockDescriptor(curr_block)

                        if curr_block == new_block:
                            if curr_block in block_match:
                                del block_match[curr_block]
                                duplicates.append(curr_block)
                            elif curr_block not in duplicates:
                                block_match[curr_block] = new_block

                # Rename the functions referenced from the matched blocks,
                # but only when the current name resolves to one function.
                for curr_block, new_block in block_match.items():
                    for curr_function, new_function in \
                            zip(curr_block.functions, new_block.functions):
                        functions = utils.find_function(
                            self._program, curr_function)
                        if len(functions) == 1:
                            if new_function not in renamed:
                                renamed.add(new_function)
                                if self._rename_functions(
                                        functions[0], new_function):
                                    rename_count += 1

        print('Renamed %d functions.' % rename_count)

    def _find_match(self, signatures):
        """
        Match the supplied signatures against the current program's own.

        :param signatures: Signatures to find in current program.
        :type signatures: RizzoSignatures

        :returns: Tuple of matched signatures: (formal, string, immediate, fuzzy)
        :rtype: tuple
        """
        def match(kind, label, *extra):
            # Compare one signature category, supplied versus local.
            return find_signature_matches(
                getattr(signatures, kind), getattr(self._signatures, kind),
                signatures.functions, self._signatures.functions,
                label, *extra)

        formal_signatures = match('formal', 'formal signatures')
        string_signatures = match('strings', 'string signatures')
        immediate_signatures = match('immediates', 'immediate signatures')
        # Fuzzy matches additionally require the same number of blocks.
        fuzzy_signatures = match(
            'fuzzy', 'fuzzy signatures',
            lambda x, y: len(x.blocks) == len(y.blocks))

        return (formal_signatures, string_signatures, immediate_signatures,
                fuzzy_signatures)

    def _rename_functions(self, function, name):
        """
        Rename a function if the function has not be renamed and new name
        is a valid new function name. Previous renamed are determined by 
        searching for 'FUN_' in the function.

        :param function: Function to be renamed.
        :type function: ghidra.program.model.listing.Function

        :param name: New name to give function.
        :type name: unicode

        :returns: True if function renamed, False for no rename.
        :rtype: bool
        """
        if not function or not name:
            return False

        if 'FUN_' in function.name and 'FUN_' not in name:
            if function:
                print 'Renaming %s to %s' % (function.name, name)
                function.setName(name, SourceType.USER_DEFINED)
                return True
        elif 'FUN_' not in function.name and 'FUN_' not in name and \
                function.name != name:
            print 'Found match with %s to %s but did not rename.' % \
                (function.name, name)
        return False

    def _signature_hash(self, value):
        """
        Simple hash function used to create a signature.

        :param value: Value to hash.
        :type value: variable

        :returns: Signature hash
        :rtype: int
        """
        return hash(str(value)) & 0xFFFFFFFF

    def _find_strings(self):
        """
        Record every string found in initialized memory, keyed by the hash
        code of its address, together with the references to it.
        """
        searchable = self._memory_map.getAllInitializedAddressSet()

        for found in self._flat_api.findStrings(searchable, 2, 1, True, True):
            location = found.getAddress()
            text = found.getString(self._memory_map)
            references = self._flat_api.getReferencesTo(location)
            self._strings[location.hashCode()] = RizzoString(
                location, text, references)

    def _get_function_blocks(self, function):
        """
        Get all code blocks in the provided function.

        :param function: Function to get code blocks from.
        :type function: ghidra.program.model.listing.Function

        :returns: List of code blocks.
        :rtype: ghidra.program.model.block.CodeBlock
        """
        blocks = []
        code_blocks = self._simple_blk.getCodeBlocksContaining(
            function.body, self._monitor)

        while code_blocks.hasNext():
            blocks.append(code_blocks.next())

        return blocks

    def _hash_block(self, block):
        """
        Create signatures for the provided code block.

        Instructions are visited from the block's minimum address while
        their address is below the block's maximum address.

        :param block: Code block to create signatures for.

        :returns: Tuple of (formal signature, fuzzy signature, immediates,
                  called function names).
        :rtype: tuple
        """
        formal_parts = []
        fuzzy_parts = []
        called = []
        immediates = []

        last_addr = block.maxAddress
        instruction = self._flat_api.getInstructionAt(block.minAddress)

        while instruction and instruction.getAddress() < last_addr:
            code_refs = []
            data_refs = []

            # Sort the outgoing references into code and data references.
            for ref in instruction.getReferencesFrom():
                # Stack references are not tracked.
                if ref.isStackReference():
                    continue

                if is_code_ref(ref):
                    code_refs.append(ref)
                # Data references only count if they point at valid memory.
                elif is_data_ref(ref) and \
                        self._memory_map.contains(ref.toAddress):
                    data_refs.append(ref)

            # The mnemonic always goes into the formal signature.
            formal_parts.append(instruction.getMnemonicString())

            if is_call_instruction(instruction):
                # Record the callee names and mark the call in the fuzzy
                # signature.
                for ref in code_refs:
                    callee = self._flat_api.getFunctionAt(ref.toAddress)
                    if callee:
                        called.append(callee.getName())
                        fuzzy_parts.append('funcref')
            elif data_refs:
                # Make note of any data references, using the string value
                # when the target is a known string.
                for ref in data_refs:
                    key = ref.toAddress.hashCode()
                    if key in self._strings:
                        token = self._strings[key].value
                    else:
                        token = 'dataref'

                    formal_parts.append(token)
                    fuzzy_parts.append(token)
            elif not code_refs:
                # Neither data nor code references: every operand goes into
                # the formal signature.
                for index in range(instruction.getNumOperands()):
                    formal_parts.append(
                        instruction.getDefaultOperandRepresentation(index))

                    if not instruction.getOperandRefType(index).isData():
                        continue

                    # Operand objects vary in type; ones without getValue
                    # (or with no objects at all) are simply ignored.
                    try:
                        immediate = \
                            instruction.getOpObjects(index)[0].getValue()
                        if immediate > 0xFFFF:
                            fuzzy_parts.append(str(immediate))
                            immediates.append(immediate)
                    except (AttributeError, IndexError):
                        pass

            instruction = instruction.getNext()

        formal_sig = self._signature_hash(''.join(formal_parts))
        fuzzy_sig = self._signature_hash(''.join(fuzzy_parts))

        return (formal_sig, fuzzy_sig, immediates, called)

    def _hash_function(self, function):
        """
        Create a block by block signature for the provided function.

        :param function: Function to create signature hash for.
        :type function: ghidra.program.model.listing.Function

        :returns: List of signatures per block found.
        """
        block_hash = []

        func_blocks = self._get_function_blocks(function)
        for block in func_blocks:
            block_hash.append(self._hash_block(block))

        return block_hash

    def _generate(self):
        """
        Build string, formal, fuzzy and immediate signatures for the current
        program.

        :returns: The populated signature container.
        """
        signatures = RizzoSignature()

        # String based signatures
        for curr_string in self._strings.values():
            # Only reasonably long strings with exactly one reference.
            if len(curr_string.value) < 8 or len(curr_string.xrefs) != 1:
                continue

            function = self._flat_api.getFunctionContaining(
                curr_string.xrefs[0].fromAddress)
            if not function:
                continue

            string_hash = self._signature_hash(curr_string.value)
            entry = utils.address_to_int(function.getEntryPoint())
            signatures.add_string(string_hash, entry)

        # Formal, fuzzy, and immediate-based function signatures
        for function in self._function_manager.getFunctions(True):
            block_hashes = self._hash_function(function)

            # Each entry is (formal, fuzzy, immediates, functions).
            formal = self._signature_hash(
                ''.join([str(entry[0]) for entry in block_hashes]))
            fuzzy = self._signature_hash(
                ''.join([str(entry[1]) for entry in block_hashes]))
            immediate = [str(entry[2]) for entry in block_hashes]

            function_entry = utils.address_to_int(function.getEntryPoint())
            signatures.functions[function_entry] = (function.getName(),
                                                    block_hashes)

            signatures.add_formal(formal, function_entry)
            signatures.add_fuzzy(fuzzy, function_entry)

            for value in immediate:
                signatures.add_immediate(value, function_entry)

        signatures.reset_dups()

        return signatures
Esempio n. 16
0
class ArmRop(object):
    def __init__(self, program):
        """
        Wrap the program and collect its controllable calls up front.

        :param program: Program to search for gadgets.
        :type program: ghidra.program.model.listing.Program
        """
        self._currentProgram = program
        self._flat_api = FlatProgramAPI(program)

        # Both lists are filled by _find_controllable_calls.
        self.controllable_calls = []
        self.controllable_terminating_calls = []
        self._find_controllable_calls()

    def find_instructions(self,
                          instructions,
                          preserve_register=None,
                          controllable_calls=True,
                          terminating_calls=True,
                          overwrite_register=None):
        """
        Search for gadgets that contain user defined instructions.

        :param instructions: List of instructions to search for.
        :type instructions: list(MipsInstruction)

        :param preserve_register: Registers to preserve.
        :type preserve_register: str

        :param controllable_calls: Search within controllable jumps.
        :type controllable_calls: bool

        :param terminating_calls: Search within controllable function epilogues.
        :type terminating_calls: bool

        :param overwrite_register: Register to ensure is overwritten.
        :type overwrite_register: str

        :returns: List of rop gadgets that contain the provided instructions.
        :rtype: list(RopGadgets)
        """
        found = RopGadgets()

        # Work on a copy so the stored call lists are never modified.
        candidates = list(self.controllable_calls) if controllable_calls \
            else []
        if terminating_calls:
            candidates += self.controllable_terminating_calls

        for candidate in candidates:
            action = self._find_instruction(candidate, instructions,
                                            preserve_register,
                                            overwrite_register)
            if not action:
                continue
            if self._is_valid_action(candidate, action):
                found.append(RopGadget(action, candidate))

        return found

    def find_doubles(self):
        """
        Pair each controllable call with the terminating calls that follow
        it closely in the same function.

        :returns: List of double jump gadgets.
        :rtype: DoubleGadgets
        """
        doubles = DoubleGadgets()

        for first in self.controllable_calls:
            for second in self.controllable_terminating_calls:
                second_ctrl = second.control_instruction.getAddress()
                distance = second_ctrl.subtract(first.call.getAddress())

                # The second jump must come after the first and be at most
                # 100 address units away (noted originally as 25
                # instructions).
                if not 0 < distance <= 100:
                    continue

                # Jumps in different functions are not paired.
                second_func = self._flat_api.getFunctionContaining(
                    second.call.getAddress())
                first_func = self._flat_api.getFunctionContaining(
                    first.call.getAddress())
                if second_func != first_func:
                    continue

                if self._contains_bad_calls(first, second):
                    continue
                doubles.append(DoubleGadget(first, second))

        return doubles

    def summary(self):
        """
        Print a summary of the gadgets marked by bookmarks whose comment
        starts with 'rop' (case-insensitive).

        If two such bookmarks share a comment, the duplicate is reported and
        no summary is printed.
        """
        manager = self._currentProgram.getBookmarkManager()

        rop_bookmarks = []
        for bookmark in manager.getBookmarksIterator():
            label = bookmark.getComment().lower()
            if not label.startswith('rop'):
                continue

            for earlier in rop_bookmarks:
                if earlier.getComment().lower() == label:
                    print('Duplicate bookmark found: {} at {} and {}'.format(
                        label, earlier.getAddress(), bookmark.getAddress()))
                    return
            rop_bookmarks.append(bookmark)

        # Order the gadgets by their lower-cased comment.
        rop_bookmarks.sort(key=lambda x: x.comment.lower())

        gadgets = RopGadgets()

        # Each bookmark becomes a gadget with its closest controllable jump,
        # when there is one.
        for bookmark in rop_bookmarks:
            jump = self._find_closest_controllable_jump(
                bookmark.getAddress())
            if not jump:
                continue

            action = self._flat_api.getInstructionAt(bookmark.getAddress())
            gadgets.append(RopGadget(action, jump, bookmark.getComment()))
        gadgets.print_summary()

    def _find_closest_controllable_jump(self, address):
        """
        Find closest controllable jump to the address provided.

        :param address: Address to find closest jump to.
        :type address: ghidra.program.model.address.Address

        :returns: Closest controllable jump, if it exists.
        :rtype: ControllableCall or None
        """
        controllable = self.controllable_calls + \
            self.controllable_terminating_calls

        function = self._flat_api.getFunctionContaining(address)

        closest = None

        for jump in controllable:
            jump_function = self._flat_api.getFunctionContaining(
                jump.call.getAddress())
            if function != jump_function:
                continue

            if address > jump.control_instruction.getAddress():
                continue

            if jump.call.getAddress() == address:
                return jump

            if not closest or \
                    jump.control_instruction.getAddress() <= \
                    address <= jump.call.getAddress():
                closest = jump
            else:
                control_addr = jump.control_instruction.getAddress()
                closest_distances = closest.control_instruction.getAddress()
                if control_addr.subtract(address) < \
                        closest_distances.subtract(address):
                    closest = jump
        return closest

    def _find_controllable_calls(self):
        """
        Find calls that can be controlled through saved registers.
        """
        program_base = self._currentProgram.getImageBase()

        code_manager = self._currentProgram.getCodeManager()
        instructions = code_manager.getInstructions(program_base, True)

        # Loop through each instruction in the current program.
        for ins in instructions:
            flow_type = ins.getFlowType()

            if flow_type.isCall() or flow_type.isTerminal() or \
                    (flow_type.isJump() and flow_type.isComputed()):
                current_instruction = self._flat_api.getInstructionAt(
                    ins.getAddress())
                controllable = self._find_controllable_call(
                    current_instruction)

                # Sort the controllable jump by type. Makes finding indirect
                # function calls easier.
                if controllable:
                    if flow_type.isCall() and not flow_type.isTerminal():
                        self.controllable_calls.append(controllable)
                    elif flow_type.isTerminal() or \
                            (flow_type.isJump() and flow_type.isComputed()):
                        self.controllable_terminating_calls.append(
                            controllable)

    def _find_controllable_call(self, call_instruction):
        """
        Decide whether a call instruction is controllable by matching it
        against the accepted instruction patterns.

        :param call_instruction: Instruction that contains a call.
        :type call_instruction: ghidra.program.model.listing.Instruction

        :returns: Controllable call object if controllable, None if not.
        :rtype: ControllableCall or None
        """
        accepted = [
            ArmInstruction('blx', '[r][0123456789]'),
            ArmInstruction('bx', '[r][0123456789]'),
            ArmInstruction('ldmia', 'sp*'),
        ]

        if not instruction_matches(call_instruction, accepted):
            return None

        # The call serves as its own control instruction here.
        return ControllableCall(call_instruction, call_instruction)

    def _get_previous_instruction(self, instruction):
        """
        Get the previous instruction. Check the "flow" first, if not found
        just return the previous memory instruction.

        :param instruction: Instruction to retrieve previous instruction from.
        :type instruction: ghidra.program.model.listing.Instruction
        """
        fall_from = instruction.getFallFrom()
        if fall_from is None:
            previous_ins = instruction.getPrevious()
        else:
            previous_ins = self._flat_api.getInstructionAt(fall_from)

        return previous_ins

    def _find_instruction(self,
                          controllable_call,
                          search_instructions,
                          preserve_reg=None,
                          overwrite_reg=None):
        """
        Search for an instruction within a controllable call.

        The call itself is checked first, then the instructions preceding it
        are walked backwards. The walk stops without a result at any call or
        jump, when it leaves the call's function, or when a register that
        must survive is overwritten.

        :param controllable_call: Controllable call to search within.
        :type controllable_call: ControllableCall

        :param search_instructions: Instruction list to search for.
        :type search_instructions: list(MipsInstruction)

        :param preserve_reg: Register to preserve, if overwritten the
                             instruction will not be returned.
        :type preserve_reg: str

        :param overwrite_reg: Enforce a register was overwritten.
        :type overwrite_reg: str

        :returns: The matching instruction if found, None otherwise.
        :rtype: ghidra.program.model.listing.Instruction
        """
        overwritten = False

        if instruction_matches(controllable_call.call, search_instructions):
            return controllable_call.call

        previous_ins = self._get_previous_instruction(controllable_call.call)
        function = self._flat_api.getFunctionContaining(
            controllable_call.call.getAddress())

        while previous_ins:
            # Break if we hit a call or jump.
            if utils.is_call_instruction(previous_ins) or \
                    utils.is_jump_instruction(previous_ins):
                return None

            # Break if we entered a different function.
            if function != self._flat_api.getFunctionContaining(
                    previous_ins.getAddress()):
                return None

            # Skip NOPs, then restart the iteration so the instruction that
            # was stepped onto goes through the guards above and a missing
            # predecessor ends the search instead of raising.
            if 'nop' in str(previous_ins):
                previous_ins = previous_ins.getPrevious()
                continue

            if instruction_matches(previous_ins, search_instructions):
                # A match only counts if the required overwrite was seen
                # between it and the call.
                if overwrite_reg and not overwritten:
                    return None
                return previous_ins

            if preserve_reg and \
                    register_overwritten(previous_ins, preserve_reg):
                return None

            if overwrite_reg and register_overwritten(previous_ins,
                                                      overwrite_reg):
                overwritten = True

            # TODO: Need to see if we passed the point of caring.
            if register_overwritten(previous_ins,
                                    controllable_call.get_control_item()):
                return None

            previous_ins = self._get_previous_instruction(previous_ins)

        return None

    def _is_valid_action(self, controllable_call, action):
        """
        Determine if an action is valid for the controllable call.

        An action is valid when it is the call itself, or when no
        instruction between the call and the action (walking backwards from
        the call, action excluded) overwrites the action's first operand.

        :param controllable_call: Controllable call to search within.
        :type controllable_call: ControllableCall

        :param action: Action to validate.
        :type action: ghidra.program.model.listing.Instruction

        :returns: True if the action is valid, False otherwise.
        :rtype: bool
        """
        if controllable_call.call == action:
            return True

        previous_ins = controllable_call.call
        # The register written by the action has to reach the call intact.
        # NOTE(review): when the action has no operand objects the empty
        # result is passed through unchanged, as before.
        preserve_reg = action.getOpObjects(0)
        if preserve_reg:
            preserve_reg = str(preserve_reg[0])

        while previous_ins and previous_ins != action:
            # Skip NOPs, then re-check the loop condition: the skip may run
            # out of instructions or land on the action itself, which must
            # not be tested against its own destination register.
            if 'nop' in str(previous_ins):
                previous_ins = previous_ins.getPrevious()
                continue

            if register_overwritten(previous_ins, preserve_reg):
                return False

            previous_ins = self._get_previous_instruction(previous_ins)

        return True

    def _contains_bad_calls(self, first, second):
        """
        Search for bad calls between two controllable jumps.

        Walks backwards from the second jump's control instruction down to
        (but not including) the first jump's call. Any branch found in
        between counts as bad.

        :param first: Controllable call that comes first in memory.
        :type first: ControllableCall

        :param second: Controllable call that comes second in memory.
        :type second: ControllableCall

        :returns: True if bad calls are found, False otherwise.
        :rtype: bool
        """
        branch = ArmInstruction('b.*')

        end_ins = first.call

        previous_ins = self._get_previous_instruction(
            second.control_instruction)

        while previous_ins is not None and \
                previous_ins.getAddress() > end_ins.getAddress():
            # Skip NOPs, then re-check the loop bounds: the skip may run out
            # of instructions or land on the first call itself, which must
            # not be inspected (it would match the branch pattern).
            if 'nop' in str(previous_ins):
                previous_ins = previous_ins.getPrevious()
                continue

            if instruction_matches(previous_ins, [branch]):
                return True

            previous_ins = self._get_previous_instruction(previous_ins)

        return False
Esempio n. 17
0
# All-ones masks keyed by bit width (as a string).
bitness_masks = {str(width): (1 << width) - 1 for width in (16, 32, 64)}

# Operator symbol used for each supported binary p-code operation.
BINARY_PCODE_OPS = {
    PcodeOp.INT_ADD: '+',
    PcodeOp.PTRSUB: '+',
    PcodeOp.INT_SUB: '-',
    PcodeOp.INT_MULT: '*',
}

cp = currentProgram
fp = FlatProgramAPI(cp)
space_ram = None
space_uniq = None

name2space = dict(register={}, unique={})


def get_high_function(func):
    """
    Decompile ``func`` and return the resulting high function.

    A new DecompInterface is configured with default DecompileOptions and
    opened on the program returned by ``getCurrentProgram()``. The function
    is decompiled with a timeout argument of 60 and a ConsoleTaskMonitor.

    :param func: Function to decompile.
    :returns: The value of ``getHighFunction()`` on the decompile results.
    """
    decompiler = DecompInterface()
    decompiler.setOptions(DecompileOptions())
    decompiler.openProgram(getCurrentProgram())

    results = decompiler.decompileFunction(func, 60, ConsoleTaskMonitor())
    return results.getHighFunction()
Esempio n. 18
0
 def __init__(self, currentProgram):
     """
     Store the program, wrap it in a FlatProgramAPI and cache the offset
     of the image base taken from the program's address map.

     :param currentProgram: Program to operate on.
     """
     self.p = currentProgram
     self.a = FlatProgramAPI(currentProgram)

     address_map = currentProgram.getAddressMap()
     self.image_base_offset = address_map.imageBase.getOffset()
Esempio n. 19
0
    def __init__(self, program, selection, monitor, arch, abi='default'):
        """
        Locate syscall instructions in an ELF program and annotate them.

        Shows a popup and returns early when the program is not an ELF file,
        when ``arch`` is not a key of ``ARCHS`` or when ``abi`` is not a key
        of ``ARCHS[arch]``. Otherwise the module globals ``SYSCALLS`` and
        ``FUNCTIONS`` are loaded through ``self.loadData``. For every
        instruction row whose endianness matches the program, each call
        returned by ``self.getSyscalls`` gets an EOL comment and a
        'Syscall' bookmark, provided its register value resolves to a known
        syscall.

        :param program: Program to analyse.
        :param selection: Selection restricting the processed addresses, or
                          None to process every call found.
        :param monitor: Monitor handed to the FlatProgramAPI.
        :param arch: Key into ``ARCHS``.
        :param abi: Key into ``ARCHS[arch]``.
        """
        global SYSCALLS, FUNCTIONS

        self.currentProgram = program
        self.currentSelection = selection
        self.monitor = monitor
        self.flatProgram = FlatProgramAPI(program, monitor)
        self.symEval = SymbolicPropogator(self.currentProgram)

        # Validate the inputs before loading any tables.
        if self.currentProgram.getExecutableFormat() != ElfLoader.ELF_NAME:
            popup('Not an ELF file, cannot continue')
            return
        if arch not in ARCHS:
            popup('Architecture not defined')
            return
        if abi not in ARCHS[arch]:
            popup('ABI not defined')
            return

        SYSCALLS = self.loadData('syscalls', arch, abi)
        FUNCTIONS = self.loadData('functions', arch, abi)

        spec = ARCHS[arch][abi]
        language = self.currentProgram.getLanguage()
        endian = language.getLanguageDescription().getEndian().toString()

        for ins in spec['ins']:
            # Only instruction encodings for the program's endianness apply.
            if ins['endian'] != endian:
                continue

            for call_addr in self.getSyscalls(ins['opcode'],
                                              ins['interrupt']):
                # Skip calls that fall outside the selection, if there is one.
                if self.currentSelection is not None and (
                        call_addr < self.currentSelection.getMinAddress()
                        or call_addr > self.currentSelection.getMaxAddress()):
                    continue

                register = self.currentProgram.getRegister(spec['reg'])
                value = self.getRegisterValue(call_addr, register)
                if value is None:
                    continue

                key = str(value)
                if key not in SYSCALLS:
                    continue

                syscall = SYSCALLS[key]
                if syscall in FUNCTIONS:
                    prototype = FUNCTIONS[syscall]
                    comment = self.getSignature(syscall, prototype)
                    self.markArguments(spec['arg'], call_addr, prototype)
                else:
                    # No prototype known: the bare syscall name is the comment.
                    comment = syscall

                self.flatProgram.setEOLComment(call_addr, comment)
                self.flatProgram.createBookmark(
                    call_addr, 'Syscall',
                    'Found %s -- %s' % (syscall, comment))
Esempio n. 20
0
    def call(self, inputs, output):
        """
        Handle a call operation by analyzing the callee and merging the
        results into this interpreter's state.

        ``inputs[0]`` must be an address varnode; the callee is looked up at
        that address. ``inputs[1:]`` are treated as the call arguments.

        Forward analysis results are computed once per callee and kept in
        ``forward_cache``. Stores, loads and arrays whose base is one of the
        callee's parameters are re-expressed over the caller's argument nodes
        and appended to ``self.stores`` / ``self.loads`` / ``self.arrays``.
        ``self.subcall_parameter_cache`` is updated as well.

        When ``output`` is not None, backward analysis results (cached in
        ``backward_cache``) are re-expressed the same way and stored for
        ``output`` through ``self.store_node``.

        :param inputs: Input varnodes of the call; at least one is required.
        :param output: Output varnode of the call, or None.
        """
        assert len(inputs) >= 1
        # First we have to analyze function forward with input arguments
        # If output exists, then we have to analyze backwards to obtain ret value types
        pc_varnode = inputs[0]
        assert pc_varnode.isAddress()
        pc_addr = pc_varnode.getAddress()
        temp = FlatProgramAPI(currentProgram)
        called_func = temp.getFunctionAt(pc_addr)
        print("call:", inputs[0].getPCAddress())

        ##### START CALL RECURSIVE FORWARD ANALYSIS

        # Note: the function analysis parameter's varnodes are different from the varnodes of our current state. Thus we replace the varnode -> Node map in the function with the calling parameters
        checkFixParameters(called_func, inputs[1:])
        if called_func not in forward_cache:
            global log
            # Cache miss: analyze the callee with a fresh interpreter.
            pci_new = PCodeInterpreter()
            parameter_varnodes = analyzeFunctionForward(called_func, pci_new)
            parameter_nodes = []
            for i in parameter_varnodes:
                # Only the first node returned by lookup_node is kept.
                parameter_nodes.append(pci_new.lookup_node(i)[0])
            forward_cache[called_func] = (pci_new.stores, pci_new.loads,
                                          parameter_nodes, pci_new.arrays,
                                          pci_new.subcall_parameter_cache)
            # The module-level `log` flag is cleared after every cache miss.
            log = False

        stores, loads, parameter_node_objects, arrs, nested_subcall_parameter_cache = forward_cache[
            called_func]
        # NOTE(review): input_node_objects is not read again in this method;
        # it duplicates node_objects below -- confirm it can be dropped.
        input_node_objects = map(self.lookup_node, inputs[1:])
        if called_func not in self.subcall_parameter_cache:
            # One empty list per declared parameter of the callee.
            param_list = []
            for i in range(called_func.getParameterCount()):
                param_list.append([])
            self.subcall_parameter_cache[called_func] = param_list

        # NOTE(review): node_objects is indexed below, which needs map() to
        # return a list (Python 2 / Jython behaviour) -- confirm the runtime.
        node_objects = map(self.lookup_node, inputs[1:])
        for i in range(len(self.subcall_parameter_cache[called_func])):
            self.subcall_parameter_cache[called_func][i] += node_objects[i]

        # Re-express callee stores based on a parameter, once per caller node
        # bound to that parameter.
        for i in stores:
            arg_idx = i.find_base_idx(parameter_node_objects)
            if arg_idx is not None:
                for j in node_objects[arg_idx]:
                    self.stores.append(
                        i.replace_base_parameters(parameter_node_objects, j))
                    if i in arrs:
                        self.arrays.append(self.stores[-1])
        # Same treatment for callee loads.
        for i in loads:
            arg_idx = i.find_base_idx(parameter_node_objects)
            if arg_idx is not None:
                for j in node_objects[arg_idx]:
                    self.loads.append(
                        i.replace_base_parameters(parameter_node_objects, j))
                    if i in arrs:
                        self.arrays.append(self.loads[-1])

        ##### END CALL RECURSIVE FORWARD ANALYSIS

        # replace args in parameter cache:
        # The keys are used as functions here (getParameterCount is called on
        # them) despite the loop variable being named func_name.
        for func_name in nested_subcall_parameter_cache:
            current_params = nested_subcall_parameter_cache[func_name]
            for param_idx in range(len(current_params)):
                # This loop rebinds `temp`, which held the FlatProgramAPI.
                for temp in current_params[param_idx]:
                    arg_idx = temp.find_base_idx(parameter_node_objects)
                    if arg_idx is not None:
                        for j in node_objects[arg_idx]:
                            replaced = temp.replace_base_parameters(
                                parameter_node_objects, j)
                            if func_name not in self.subcall_parameter_cache:
                                param_list = []
                                for i in range(func_name.getParameterCount()):
                                    param_list.append([])
                                self.subcall_parameter_cache[
                                    func_name] = param_list
                            # NOTE(review): the slot is selected with arg_idx
                            # (index into the callee's parameters), not
                            # param_idx (index into func_name's parameters)
                            # -- confirm this is intended.
                            if arg_idx < len(
                                    self.subcall_parameter_cache[func_name]):
                                self.subcall_parameter_cache[func_name][
                                    arg_idx].append(replaced)

        if output is not None:
            if called_func not in backward_cache:  # This means we want to backwards interpolate the return type
                ##### START CALL RECURSIVE BACKWARDS ANALYSIS

                checkFixReturn(called_func, output)
                pci_new = PCodeInterpreter()
                ret_type, subfunc_parameter_varnodes = analyzeFunctionBackward(
                    called_func, pci_new)
                backward_cache[called_func] = (ret_type,
                                               map(pci_new.lookup_node,
                                                   subfunc_parameter_varnodes))

                ##### END CALL RECURSIVE BACKWARDS ANALYSIS

            ret_type, subfunc_parameter_node_objs = backward_cache[called_func]
            replaced_rets = []
            for a in ret_type:
                for i in a:
                    arg_idx = i.find_base_idx(subfunc_parameter_node_objs)
                    # node_objects is rebound here; the forward-analysis list
                    # of the same name is no longer needed at this point.
                    if arg_idx is None:
                        node_objects = [1]  # Doesn't matter
                    else:
                        node_objects = self.lookup_node(inputs[1:][arg_idx])
                    for j in node_objects:
                        replaced_rets.append(
                            i.replace_base_parameters(
                                subfunc_parameter_node_objs, j))

            # Record every re-expressed return value for the output varnode.
            for i in range(len(replaced_rets)):
                self.store_node(output, replaced_rets[i])
Esempio n. 21
0
 def __init__(self, program):
     """
     Keep the program together with a FlatProgramAPI for it, that API's
     monitor and a BasicBlockModel built on the program.

     :param program: Program to operate on.
     """
     flat_api = FlatProgramAPI(program)

     self._program = program
     self._flat_api = flat_api
     self._monitor = flat_api.getMonitor()
     self._basic_blocks = BasicBlockModel(program)
Esempio n. 22
0
def declare(program, file):
    """
    Create labels and typed data for the peripheral registers in an XML file.

    Every ``peripheral`` element found under the document root is
    processed. Its registers are either its own ``registers/register``
    children or, when the element has a ``derivedFrom`` attribute, the
    register list recorded earlier for the named peripheral. For each
    register a primary label ``<peripheral>_<register>`` is created at
    ``baseAddress + addressOffset`` in the default address space. A byte,
    word or dword is then defined there when the register size is 8, 16 or
    32; other sizes get the label only.

    Any failure while labelling or defining a register is reported with a
    "skipping address" message and processing continues.

    :param program: Program to annotate.
    :param file: Path or file object accepted by ``ET.parse``.
    """
    root = ET.parse(file).getroot()

    api = FlatProgramAPI(program)
    aspace = api.getAddressFactory().getDefaultAddressSpace()

    # Data creators keyed by register width in bits.
    creators = {
        8: api.createByte,
        16: api.createWord,
        32: api.createDWord,
    }

    # Register lists by peripheral name, so derived peripherals can reuse them.
    periph_map = {}

    for periph in root.findall('.//peripheral'):
        periph_name = get_text(periph, 'name', 'UNK')
        periph_base = get_int(periph, 'baseAddress')

        parent_name = periph.get('derivedFrom')
        regs = (periph_map[parent_name] if parent_name
                else periph.findall('./registers/register'))
        periph_map[periph_name] = regs

        for reg in regs:
            reg_name = get_text(reg, 'name', 'UNK')
            reg_addr = periph_base + get_int(reg, 'addressOffset')
            reg_size = get_int(reg, 'size')

            addr = aspace.getAddress(reg_addr)
            try:
                api.createLabel(addr, "%s_%s" % (periph_name, reg_name), True)

                create = creators.get(reg_size)
                if create is not None:
                    create(addr)
            except:
                # Deliberately broad: best effort, keep going on any failure.
                print("skipping address 0x%08x" % reg_addr)
Esempio n. 23
0
# Abort unless both the is_hex and is_even checks passed.
if not is_hex or not is_even:
    print "Error: Please only enter hex values."
    exit()

# Rewrite the hex string two characters at a time into "\xNN" escape text,
# the form passed to findBytes below.
instr_bytes = "".join(
    ["\\x" + instr_bytes[i:i + 2] for i in range(0, len(instr_bytes), 2)])

decompInterface = DecompInterface()
decompInterface.openProgram(currentProgram)

# ghidra options
newOptions = DecompileOptions()  # Current decompiler options
newOptions.setMaxPayloadMBytes(max_payload_mbytes)
decompInterface.setOptions(newOptions)
listing = currentProgram.getListing()
fpapi = FlatProgramAPI(currentProgram)
address_factory = fpapi.getAddressFactory()
psedu_disassembler = PseudoDisassembler(currentProgram)

# search for the specified bytes
minAddress = currentProgram.getMinAddress()
instr_addresses = fpapi.findBytes(minAddress, instr_bytes, matchLimit)

fixed = 0
for target_address in instr_addresses:
    # check if ghidra got this one right: an instruction already exists at
    # the match, so there is nothing to do for it
    disassembled_instr = fpapi.getInstructionAt(target_address)
    if disassembled_instr is not None:
        continue

    print "found the bytes at: " + str(target_address)
Esempio n. 24
0
from ghidra.program.model.symbol import SourceType
from java.util import ArrayList


# Empty at import time. load_api_names_types() fills TL_APIS with entries
# mapping an id to a (name, type) tuple.
TL_APIS = {}
# NOTE(review): presumably the highest valid TL API id -- confirm.
MAX_TL_API = 0xC2

# Empty at import time.
# NOTE(review): presumably filled like TL_APIS; the loading code is not
# visible here -- confirm.
DR_APIS = {}
# NOTE(review): presumably the highest valid DR API id -- confirm.
MAX_DR_API = 0x3D


# Shared Ghidra helpers, all bound to the current program.
blockModel = BasicBlockModel(currentProgram)
functionManager = currentProgram.getFunctionManager()
decompInterface = DecompInterface()
decompInterface.openProgram(currentProgram)
api = FlatProgramAPI(currentProgram, monitor)


def read_dword(address):
    """
    Read four bytes at ``address`` through the module-level ``api`` and
    return them decoded as a little-endian unsigned 32-bit integer.

    :param address: Address to read from.
    :returns: The decoded integer.
    """
    (value,) = struct.unpack('<I', api.getBytes(address, 4))
    return value


def load_api_names_types():
    """
    Load API names and types from ``tl_apis.json`` into ``TL_APIS``.

    The JSON file is read from the directory containing this script. Each
    record is expected to be an ``(id, name, type)`` triple and is stored
    in the module-level ``TL_APIS`` dict as ``id -> (name, type)``.
    """
    # Resolve this script's own directory via a lambda defined in this file.
    source_file = inspect.getsourcefile(lambda: 0)
    curdir = os.path.dirname(os.path.abspath(source_file))

    with open(os.path.join(curdir, 'tl_apis.json'), 'r') as f:
        records = json.load(f)

    # Build the full mapping first so TL_APIS is only touched if every
    # record unpacks cleanly.
    parsed = {}
    for api_id, api_name, api_type in records:
        parsed[api_id] = (api_name, api_type)
    TL_APIS.update(parsed)
Esempio n. 25
0
 def __init__(self, program):
     """
     Store the program and a FlatProgramAPI for it, start with empty lists
     of controllable calls, then run ``_find_controllable_calls``.

     :param program: Program to search.
     """
     self._currentProgram = program
     self._flat_api = FlatProgramAPI(program)

     # Both result lists start empty before the search runs.
     self.controllable_calls = list()
     self.controllable_terminating_calls = list()

     self._find_controllable_calls()
 def start(self):
     self._program = self._currentProgram
     self._listing = self._program.getListing()
     from ghidra.program.flatapi import FlatProgramAPI
     self._flatapi = FlatProgramAPI(self._program)