예제 #1
0
class CryptoIdentifier():
    """
    This class contains the logic to perform Crypto identification.
    Two techniques are currently supported:
    1. A heuristic approach that identifies functions and basic blocks
    based on the ratio of arithmetic/logic instructions to all instructions
    2. A signature-based approach, using the signatures defined in PatternManager
    """
    def __init__(self):
        self.name = "CryptoIdentifier"
        print("[*] loading CryptoIdentifier")
        self.time = time
        self.re = re
        self.GraphHelper = GraphHelper
        self.CryptoSignatureHit = CryptoSignatureHit
        self.AritlogBasicBlock = AritlogBasicBlock
        self.Segment = Segment
        self.pm = PatternManager()
        self.low_rating_threshold = 0.4
        self.high_rating_threshold = 1.0
        self.low_instruction_threshold = 8
        self.high_instruction_threshold = 100
        # if the threshold is set to this value, it is automatically expanded to infinite.
        self.max_instruction_threshold = 100
        self.low_call_threshold = 0
        self.high_call_threshold = 1
        # if the threshold is set to this value, it is automatically expanded to infinite.
        self.max_call_threshold = 10
        # if at least this fraction of a signature's length has been identified
        # consecutively, the location is marked as a signature hit.
        self.match_filter_factor = 0.5
        self.aritlog_blocks = []
        self.signature_hits = []
        self.ida_proxy = IdaProxy()

    def scan(self):
        """
        Scan the whole IDB with all available techniques.
        """
        self.scanAritlog()
        self.scanCryptoPatterns()

################################################################################
# Aritlog scanning
################################################################################

    def scanAritlog(self):
        """
        Scan all functions with the arithmetic/logic heuristic.

        Every basic block of every function becomes an AritlogBasicBlock that records its
        instruction statistics, whether it lies in a loop (a block that is its own successor,
        or a strongly connected component with more than one block), and the number of
        "call" instructions in the containing function. All analyzed blocks are stored in
        self.aritlog_blocks.
        @return: a list of AritlogBasicBlock data objects that fulfill the current thresholds
        """
        print("[*] CryptoIdentifier: Starting aritlog heuristic analysis.")
        self.aritlog_blocks = []
        time_before = self.time.time()
        for function_ea in self.ida_proxy.Functions():
            function_chart = self.ida_proxy.FlowChart(
                self.ida_proxy.get_func(function_ea))
            calls_in_function = 0
            function_blocks = []
            function_dgraph = {}
            blocks_in_loops = set()
            for current_block in function_chart:
                block = self.AritlogBasicBlock(current_block.startEA,
                                               current_block.endEA)
                for instruction in self.ida_proxy.Heads(
                        block.start_ea, block.end_ea):
                    if self.ida_proxy.isCode(
                            self.ida_proxy.GetFlags(instruction)):
                        mnemonic = self.ida_proxy.GetMnem(instruction)
                        # identical operands are used to recognize zeroing instructions
                        # such as "xor eax, eax" (see is_nonzero in getAritlogBlocks)
                        has_identical_operands = self.ida_proxy.GetOperandValue(instruction, 0) == \
                            self.ida_proxy.GetOperandValue(instruction, 1)
                        block.updateInstructionCount(mnemonic,
                                                     has_identical_operands)
                        if mnemonic == "call":
                            calls_in_function += 1
                function_blocks.append(block)
                # prepare graph for Tarjan's algorithm
                succeeding_blocks = [
                    succ.startEA for succ in current_block.succs()
                ]
                function_dgraph[current_block.startEA] = succeeding_blocks
                # a block that is its own successor forms a trivial loop
                if current_block.startEA in succeeding_blocks:
                    block.is_contained_in_trivial_loop = True
                    blocks_in_loops.add(current_block.startEA)
            # strongly connected components (= loops) with more than one block are non-trivial loops
            graph_helper = self.GraphHelper()
            strongly_connected = graph_helper.calculateStronglyConnectedComponents(
                function_dgraph)
            for component in strongly_connected:
                if len(component) > 1:
                    blocks_in_loops.update(component)
            for block in function_blocks:
                if block.start_ea in blocks_in_loops:
                    block.is_contained_in_loop = True
                block.num_calls_in_function = calls_in_function
            self.aritlog_blocks.extend(function_blocks)
        print("[*] Heuristics analysis took %3.2f seconds." %
              (self.time.time() - time_before))

        return self.getAritlogBlocks(
            self.low_rating_threshold, self.high_rating_threshold,
            self.low_instruction_threshold, self.high_instruction_threshold,
            self.low_call_threshold, self.high_call_threshold, False, False,
            False)

    def _updateThresholds(self, min_rating, max_rating, min_instr, max_instr,
                          min_call, max_call):
        """
        update all six threshold bounds
        @param min_rating: the minimum arit/log ratio a basic block must have (clamped to >= 0.0)
        @type min_rating: float
        @param max_rating: the maximum arit/log ratio a basic block can have (clamped to <= 1.0)
        @type max_rating: float
        @param min_instr: the minimum number of instructions a basic block must have (clamped to >= 0)
        @type min_instr: int
        @param max_instr: the maximum number of instructions a basic block can have; values of at
                          least max_instruction_threshold mean "unbounded"
        @type max_instr: int
        @param min_call: the minimum number of calls the containing function must have (clamped to >= 0)
        @type min_call: int
        @param max_call: the maximum number of calls the containing function can have; values of at
                         least max_call_threshold mean "unbounded"
        @type max_call: int
        """
        self.low_rating_threshold = max(0.0, min_rating)
        self.high_rating_threshold = min(1.0, max_rating)
        self.low_instruction_threshold = max(0, min_instr)
        if max_instr >= self.max_instruction_threshold:
            # we cap the value here and safely assume there is no block with more than 1000000 instructions
            self.high_instruction_threshold = 1000000
        else:
            self.high_instruction_threshold = max_instr
        self.low_call_threshold = max(0, min_call)
        if max_call >= self.max_call_threshold:
            # we cap the value here and safely assume there is no function with more than 1000000 calls
            self.high_call_threshold = 1000000
        else:
            self.high_call_threshold = max_call

    def getAritlogBlocks(self, min_rating, max_rating, min_instr, max_instr, min_api, max_api, is_nonzero, \
        is_looped, is_trivially_looped):
        """
        get all blocks that are within the limits specified by the heuristic parameters.
        parameters are the same as in function "_updateThresholds" (min_api/max_api are the call
        bounds min_call/max_call) except
        param is_nonzero: defines whether zeroing instructions (like xor eax, eax) shall be counted or not.
        type is_nonzero: boolean
        param is_looped: defines whether only basic blocks in loops shall be selected
        type is_looped: boolean
        param is_trivially_looped: defines whether only basic blocks that are their own successor shall be selected
        type is_trivially_looped: boolean
        @return: a list of AritlogBasicBlock data objects, according to the parameters
        """
        self._updateThresholds(min_rating, max_rating, min_instr, max_instr,
                               min_api, max_api)
        return [
            block for block in self.aritlog_blocks
            if (self.high_rating_threshold >= block.getAritlogRating(is_nonzero)
                >= self.low_rating_threshold)
            and (self.high_instruction_threshold >= block.num_instructions
                 >= self.low_instruction_threshold)
            and (self.high_call_threshold >= block.num_calls_in_function
                 >= self.low_call_threshold)
            and (not is_looped or block.is_contained_in_loop)
            and (not is_trivially_looped or block.is_contained_in_trivial_loop)
        ]

    def getUnfilteredBlockCount(self):
        """
        returns the number of basic blocks that have been analyzed.
        @return: (int) number of basic blocks
        """
        return len(self.aritlog_blocks)

################################################################################
# Signature scanning
################################################################################

    def getSegmentData(self):
        """
        returns the raw bytes of the segments as stored by IDA.
        Segments that cannot be read are skipped with a warning.
        @return: a list of Segment data objects.
        """
        segments = []
        for segment_ea in self.ida_proxy.Segments():
            try:
                segment = self.Segment()
                segment.start_ea = segment_ea
                segment.end_ea = self.ida_proxy.SegEnd(segment_ea)
                segment.name = self.ida_proxy.SegName(segment_ea)
                # join once instead of repeated (quadratic) string concatenation
                segment.data = "".join(
                    chr(self.ida_proxy.get_byte(ea))
                    for ea in helpers.Misc.lrange(segment_ea, segment.end_ea))
                segments.append(segment)
            except Exception:
                print(
                    "[!] Tried to access invalid segment data. An error has occurred while address conversion"
                )
        return segments

    def scanCryptoPatterns(self, pattern_size=32):
        """
        Scan all segments for the tokenized crypto signatures and the variable signatures.
        @param pattern_size: passed through to PatternManager.getTokenizedSignatures
        @type pattern_size: int
        @return: a list of CryptoSignatureHit data objects (also stored in self.signature_hits)
        """
        crypt_results = []
        print("[*] CryptoIdentifier: Starting crypto signature scanning.")
        time_before_matching = self.time.time()
        segments = self.getSegmentData()
        keywords = self.pm.getTokenizedSignatures(pattern_size)
        for keyword, signature_names in keywords.items():
            # escape/compile once per keyword instead of once per segment
            keyword_regex = self.re.compile(self.re.escape(keyword))
            for segment in segments:
                crypt_results.extend([
                    self.CryptoSignatureHit(segment.start_ea + match.start(),
                                            signature_names, keyword)
                    for match in keyword_regex.finditer(segment.data)
                ])
        variable_matches = self.scanVariablePatterns()
        crypt_results.extend(variable_matches)
        print("[*] Full matching took %3.2f seconds and resulted in %d hits." %
              (self.time.time() - time_before_matching, len(crypt_results)))
        self.signature_hits = crypt_results
        return crypt_results

    def scanVariablePatterns(self):
        """
        Search all segments for the variable signatures of the PatternManager using IDA's
        find_binary. After a hit, the search resumes behind it, advancing by the number of
        space-separated tokens in the pattern.
        @return: a list of CryptoSignatureHit data objects
        """
        # the scanning code is roughly based on kyprizel's signature scan, see credits above for more information
        crypt_results = []
        variable_signatures = self.pm.getVariableSignatures()
        for var_sig, pattern in variable_signatures.items():
            pattern_length = pattern.count(" ") + 1
            current_ea = self.ida_proxy.FirstSeg()
            seg_end = self.ida_proxy.SegEnd(current_ea)
            while current_ea != self.ida_proxy.BAD_ADDR:
                signature_hit = self.ida_proxy.find_binary(
                    current_ea, seg_end, pattern, 16, 1)
                if signature_hit != self.ida_proxy.BAD_ADDR:
                    crypt_results.append(
                        self.CryptoSignatureHit(signature_hit, [var_sig], pattern))
                    current_ea = signature_hit + pattern_length
                else:
                    # nothing (more) in this segment: continue with the next one
                    current_ea = self.ida_proxy.NextSeg(seg_end)
                    if current_ea != self.ida_proxy.BAD_ADDR:
                        seg_end = self.ida_proxy.SegEnd(current_ea)
        return crypt_results

    def getSignatureLength(self, signature_name):
        """
        returns the length for a signature, identified by its name
        @param signature_name: name for a signature, e.g. "ADLER 32"
        @type signature_name: str
        @return: (int) length of the signature, 0 if the name is unknown.
        """
        for signature, name in self.pm.signatures.items():
            if name == signature_name:
                return len(signature)
        return 0

    def getSignatureHits(self):
        """
        Get all signature hits that have a length of at least match_filter_factor percent
        of the signature they triggered.
        Hits are grouped by signature names.
        @return: a dictionary  with key/value entries of the following form: ("signature name", [CryptoSignatureHit])
        """
        sorted_hits = sorted(self.signature_hits)
        unified_hits = []

        # merge consecutive hits that share a signature name and are contiguous in memory
        previous_signature_names = []
        for hit in sorted_hits:
            hit_intersection = [
                element for element in hit.signature_names
                if element in previous_signature_names
            ]
            if len(hit_intersection) == 0:
                previous_signature_names = hit.signature_names
                unified_hits.append(self.CryptoSignatureHit(hit.start_address, hit.signature_names,
                                                            hit.matched_signature))
            else:
                previous_signature_names = hit_intersection
                previous_hit = unified_hits[-1]
                if hit.start_address == previous_hit.start_address + len(
                        previous_hit.matched_signature):
                    previous_hit.matched_signature += hit.matched_signature
                    previous_hit.signature_names = hit_intersection
                else:
                    unified_hits.append(self.CryptoSignatureHit(hit.start_address, hit.signature_names,
                                                                hit.matched_signature))

        # keep hits covering at least match_filter_factor of their longest named signature
        filtered_hits = []
        for hit in unified_hits:
            required_lengths = [
                self.match_filter_factor * self.getSignatureLength(name)
                for name in hit.signature_names
            ]
            # max() of an empty list raises ValueError; a hit without names has no length requirement
            min_length = max(required_lengths) if required_lengths else 0
            if len(hit.matched_signature) >= min_length:
                hit.code_refs_to = self.getXrefsToAddress(hit.start_address)
                filtered_hits.append(hit)

        grouped_hits = {}
        for hit in filtered_hits:
            for name in hit.signature_names:
                grouped_hits.setdefault(name, []).append(hit)

        return grouped_hits

    def getXrefsToAddress(self, address):
        """
        get all references to a certain address.
        These are no xrefs in IDA sense but references to the crypto signatures.
        If the signature points to an instruction, e.g. if a constant is moved to a register, the return is flagged as
        "True", meaning it is an in-code reference.
        @param address: an arbitrary address
        @type address: int
        @return: a list of tuples (int, boolean)
        """
        xrefs = []
        # look for a code head at most 14 bytes before the address (the instruction embedding the constant)
        head_to_address = self.ida_proxy.PrevHead(address, address - 14)
        # compare against IDA's BADADDR instead of a 32-bit literal so 64-bit databases work too
        if head_to_address != self.ida_proxy.BAD_ADDR:
            flags = self.ida_proxy.GetFlags(head_to_address)
            if self.ida_proxy.isCode(flags):
                xrefs.append((head_to_address, True))
        for x in self.ida_proxy.XrefsTo(address):
            flags = self.ida_proxy.GetFlags(x.frm)
            if self.ida_proxy.isCode(flags):
                xrefs.append((x.frm, False))
        return xrefs
예제 #2
0
class SemanticIdentifier():
    """
    A module to analyze and explore an IDB for semantics. For a set of API names, references to these
    are identified and used for creating context and allowing tagging of them.
    """

    def __init__(self, idascope_config):
        """
        Set up the semantic identifier: bind helper modules/classes, reset the semantics state,
        collect the real import names and load the semantics definitions.
        @param idascope_config: the IDAscope configuration (semantics folder and default semantics)
        """
        print ("[|] loading SemanticIdentifier")
        # stdlib modules and helper classes kept as instance attributes
        self.os, self.re, self.time = os, re, time
        self.ida_proxy = IdaProxy()
        self.FunctionContext = FunctionContext
        self.FunctionContextFilter = FunctionContextFilter
        self.CallContext = CallContext
        self.ParameterContext = ParameterContext
        # semantics state, replaced when a semantics file is activated
        self.semantics = {}
        self.active_semantics = {}
        self.renaming_seperator = "_"
        self.semantic_groups = []
        self.semantic_definitions = []
        # import name -> name IDA shows at the import address (filled by _getRealApiNames)
        self.real_api_names = {}
        # function address -> FunctionContext from the most recent scan
        self.last_scan_result = {}
        self.idascope_config = idascope_config
        self._getRealApiNames()
        self._loadSemantics(self.idascope_config)

    def _cbEnumImports(self, addr, name, ordinal):
        """
        Callback for enum_import_names: remember the IDA name at addr for every named import.
        @return: always True (presumably tells IDA to continue the enumeration)
        """
        if not name:
            return True
        self.real_api_names[name] = self.ida_proxy.Name(addr)
        return True

    def _getRealApiNames(self):
        """
        Populate self.real_api_names by enumerating the imported names of every import module.
        """
        num_imports = self.ida_proxy.get_import_module_qty()
        # range instead of xrange: identical on Python 2, and xrange does not exist on Python 3
        for module_index in range(num_imports):
            self.ida_proxy.enum_import_names(module_index, self._cbEnumImports)

    def lookupRealApiName(self, api_name):
        """
        Map an API name to the name IDA uses for its import, if one was recorded.
        @param api_name: API name as used in the semantics definitions
        @return: (str) the recorded IDA name, or api_name itself if none is known
        """
        return self.real_api_names.get(api_name, api_name)

    def lookupDisplayApiName(self, real_api_name):
        """ returns the key by given value of self.real_api_names (basically inverted dictionary).
        A key only matches if it is also a substring of its mapped name; if several keys match,
        the last one in iteration order wins. Without a match, real_api_name is returned.
        """
        result = real_api_name
        for display_name, mapped_name in self.real_api_names.items():
            if mapped_name == real_api_name and display_name in mapped_name:
                result = display_name
        return result

    def _loadSemantics(self, config):
        """
        Load every ".json" semantics file from config.semantics_folder and activate one of them.
        The semantics named config.inspection_default_semantics is preferred; otherwise the
        alphabetically first loaded semantics is used, or an empty default if none exist.
        @param config: the IDAscope configuration
        @type config: object with semantics_folder and inspection_default_semantics attributes
        """
        folder = config.semantics_folder
        json_files = [fn for fn in self.os.listdir(folder) if fn.endswith(".json")]
        for filename in json_files:
            loaded_file = self._loadSemanticsFile(folder + self.os.sep + filename)
            self.semantics[loaded_file["name"]] = loaded_file
        if config.inspection_default_semantics in self.semantics:
            chosen = config.inspection_default_semantics
        elif self.semantics:
            chosen = min(self.semantics)
        else:
            chosen = ""
        self._setSemantics(chosen)

    def _loadSemanticsFile(self, semantics_filename):
        """
        Load and parse a single semantics definition file.
        @param semantics_filename: path of a semantics JSON file
        @type semantics_filename: str
        @return: (dict) the decoded JSON content, post-processed by JsonHelper.decode_dict
        """
        # context manager closes the handle even if reading fails (original leaked it)
        with open(semantics_filename, "r") as semantics_file:
            semantics = semantics_file.read()
        return json.loads(semantics, object_hook=JsonHelper.decode_dict)

    def _setSemantics(self, semantics_entry):
        """
        Activate the loaded semantics named semantics_entry, or an empty default ("none") if no
        such semantics was loaded, then rescan the IDB by references.
        @param semantics_entry: name of a loaded semantics
        @type semantics_entry: str
        """
        if semantics_entry not in self.semantics:
            self.renaming_seperator = "_"
            self.semantic_groups = []
            self.semantic_definitions = []
            self.active_semantics = {"name": "none"}
        else:
            chosen = self.semantics[semantics_entry]
            self.renaming_seperator = chosen["renaming_seperator"]
            self.semantic_groups = chosen["semantic_groups"]
            self.semantic_definitions = chosen["semantic_definitions"]
            self.active_semantics = chosen
        self.scanByReferences()

    def getSemanticsNames(self):
        """
        @return: (list of str) the names of all loaded semantics, sorted
        """
        names = list(self.semantics.keys())
        names.sort()
        return names

    def getActiveSemanticsName(self):
        """
        @return: (str) the "name" entry of the currently active semantics
        """
        active = self.active_semantics
        return active["name"]

    def calculateNumberOfBasicBlocksForFunctionAddress(self, function_address):
        """
        Calculates the number of basic blocks for a given function by walking its FlowChart.
        @param function_address: function address to calculate the block count for
        @type function_address: int
        @return: (int) the number of blocks counted; 0 (or the partial count) if the flow chart
                 cannot be built or walked
        """
        number_of_blocks = 0
        try:
            func_chart = self.ida_proxy.FlowChart(self.ida_proxy.get_func(function_address))
            for _ in func_chart:
                number_of_blocks += 1
        except Exception:
            # best effort: e.g. presumably no function at this address. Unlike the former bare
            # "except:", this no longer swallows KeyboardInterrupt/SystemExit.
            pass
        return number_of_blocks

    def getNumberOfBasicBlocksForFunctionAddress(self, address):
        """
        returns the number of basic blocks for the function containing the queried address,
        based on the value stored in the last scan result.

        If the function is not part of the last scan result, zero is returned.
        @param address: an address inside the function to get the block count for
        @type address: int
        @return: (int) The number of blocks in the function
        """
        function_address = self.getFunctionAddressForAddress(address)
        if function_address not in self.last_scan_result:
            return 0
        return self.last_scan_result[function_address].number_of_basic_blocks

    def scan(self):
        """
        Scan the whole IDB with all available techniques:
        first the fast reference-based scan, then the deep scan.
        """
        for scan_step in (self.scanByReferences, self.scanDeep):
            scan_step()

    def scanByReferences(self):
        """
        Scan by references to API names, based on the definitions loaded from the config file.
        This is highly efficient because we only touch places in the IDB that actually have references
        to our API names of interest. The previous scan result is discarded; every reference
        adds a CallContext to the FunctionContext of the referencing function.
        """
        print ("  [/] SemanticIdentifier: Starting (fast) scan by references of function semantics.")
        start_time = self.time.time()
        self.last_scan_result = {}
        for definition in self.semantic_definitions:
            for api_name in definition["api_names"]:
                real_name = self.lookupRealApiName(api_name)
                api_address = self.ida_proxy.LocByName(real_name)
                for ref_address in self._getAllRefsTo(api_address):
                    function_ctx = self._getFunctionContext(ref_address)
                    function_ctx.has_tags = True
                    call_ctx = self.CallContext()
                    call_ctx.called_function_name = api_name
                    call_ctx.real_called_function_name = real_name
                    call_ctx.address_of_call = ref_address
                    call_ctx.called_address = api_address
                    call_ctx.tag = definition["tag"]
                    call_ctx.group = definition["group"]
                    # parameters are resolved last, once the call context is fully populated
                    call_ctx.parameter_contexts = self._resolveApiCall(call_ctx)
                    function_ctx.call_contexts.append(call_ctx)
        elapsed = self.time.time() - start_time
        print ("  [\\] Analysis took %3.2f seconds." % elapsed)

    def _getAllRefsTo(self, addr):
        """
        @return: an iterator over the unique addresses with a code or data reference to addr
        """
        code_refs = set(self.ida_proxy.CodeRefsTo(addr, 0))
        data_refs = set(self.ida_proxy.DataRefsTo(addr))
        return iter(code_refs | data_refs)

    def _getNumRefsTo(self, addr):
        """
        @return: (int) the number of unique code/data references to addr
        """
        return len(list(self._getAllRefsTo(addr)))

    def _getAllRefsFrom(self, addr, code_only=False):
        """
        @param addr: address to collect outgoing references from
        @param code_only: if True, only keep data references whose target flags intersect
                          FL_CN | FL_CF (intended to catch calls through import pointers)
        @return: an iterator over the unique code and (possibly filtered) data reference targets
        """
        code_refs = set(self.ida_proxy.CodeRefsFrom(addr, 0))
        if not code_only:
            data_refs = set(self.ida_proxy.DataRefsFrom(addr))
        else:
            # NOTE(review): FL_CN/FL_CF look like IDA xref-type constants, yet they are masked
            # against GetFlags() here; confirm this filter does what the intent above says.
            call_mask = self.ida_proxy.FL_CN | self.ida_proxy.FL_CF
            data_refs = set(ref for ref in self.ida_proxy.DataRefsFrom(addr)
                            if self.ida_proxy.GetFlags(ref) & call_mask)
        return iter(code_refs | data_refs)

    def _getFunctionContext(self, addr):
        """
        Create or return an existing FunctionContext for the function containing addr in the
        current scan result.
        @param addr: an address inside the function to get a FunctionContext for
        @type addr: int
        @return: (FunctionContext) A reference to the corresponding function context
        """
        function_address = self.ida_proxy.LocByName(self.ida_proxy.GetFunctionName(addr))
        if function_address in self.last_scan_result:
            return self.last_scan_result[function_address]
        new_ctx = self.FunctionContext()
        new_ctx.function_address = function_address
        new_ctx.function_name = self.ida_proxy.GetFunctionName(function_address)
        # FF_LABL set on the function start is treated as "still has a dummy name"
        new_ctx.has_dummy_name = (self.ida_proxy.GetFlags(function_address) &
                                  self.ida_proxy.FF_LABL) > 0
        self.last_scan_result[new_ctx.function_address] = new_ctx
        return new_ctx

    def scanDeep(self):
        """
        Perform a full enumeration of all instructions,
        gathering information like number of instructions, number of basic blocks,
        references to and from functions etc.
        The results are written into the FunctionContext of every function in the IDB.
        """
        print ("  [/] SemanticIdentifier: Starting deep scan of function semantics.")
        start_time = self.time.time()
        for function_ea in self.ida_proxy.Functions():
            function_chart = self.ida_proxy.FlowChart(self.ida_proxy.get_func(function_ea))
            function_ctx = self._getFunctionContext(function_ea)
            block_count = 0
            head_count = 0
            outgoing_refs = []
            outgoing_calls = []
            for block in function_chart:
                block_count += 1
                for head in self.ida_proxy.Heads(block.startEA, block.endEA):
                    # every head is counted, code or not
                    head_count += 1
                    if not self.ida_proxy.isCode(self.ida_proxy.GetFlags(head)):
                        continue
                    is_call = self.ida_proxy.GetMnem(head) == "call"
                    for ref in self._getAllRefsFrom(head):
                        if is_call:
                            outgoing_calls.append(ref)
                        outgoing_refs.append(ref)
            function_ctx.calls_from.update(outgoing_calls)
            function_ctx.number_of_xrefs_to = self._getNumRefsTo(function_ea)
            function_ctx.xrefs_from.update(outgoing_refs)
            function_ctx.number_of_xrefs_from = len(outgoing_refs)
            function_ctx.number_of_basic_blocks = block_count
            function_ctx.number_of_instructions = head_count
        print ("  [\\] Analysis took %3.2f seconds." % (self.time.time() - start_time))

    def getFunctionAddressForAddress(self, address):
        """
        Get a function address containing the queried address by resolving the name of the
        containing function back to an address.
        @param address: address to check the function address for
        @type address: int
        @return: (int) The start address of the function containing this address
        """
        function_name = self.ida_proxy.GetFunctionName(address)
        return self.ida_proxy.LocByName(function_name)

    def calculateNumberOfFunctions(self):
        """
        Calculate the number of functions in all segments.
        @return: (int) the number of functions found.
        """
        proxy = self.ida_proxy
        return sum(
            1
            for seg_ea in proxy.Segments()
            for _ in proxy.Functions(proxy.SegStart(seg_ea), proxy.SegEnd(seg_ea)))

    def calculateNumberOfTaggedFunctions(self):
        """
        Calculate the number of functions in the last scan result that pass a freshly created
        FunctionContextFilter.
        @return: (int) the number of functions found.
        """
        default_filter = self.createFunctionContextFilter()
        matching_addresses = self.getFunctionAddresses(default_filter)
        return len(matching_addresses)

    def getFunctionAddresses(self, context_filter):
        """
        Get all function addresses covered by the last scanning that pass the given filter.
        @param context_filter: selects all functions (display_all), functions with an enabled tag
                               (display_tags) or with an enabled group (display_groups), optionally
                               narrowed to tagged and/or dummy-named functions
        @type context_filter: FunctionContextFilter
        @return: (list of int) The addresses of matching functions.
        """
        # list() so that a list (not a dict_keys view on Python 3) is always returned
        all_addresses = list(self.last_scan_result.keys())
        filtered_addresses = []
        if context_filter.display_all:
            filtered_addresses = all_addresses
        elif context_filter.display_tags:
            # enabled_tags entries are compared by their first element; build the set only once
            enabled_tags = set(tag[0] for tag in context_filter.enabled_tags)
            filtered_addresses = [
                address for address in all_addresses
                if enabled_tags.intersection(self.last_scan_result[address].getTags())
            ]
        elif context_filter.display_groups:
            enabled_groups = set(group[0] for group in context_filter.enabled_groups)
            filtered_addresses = [
                address for address in all_addresses
                if enabled_groups.intersection(self.last_scan_result[address].getGroups())
            ]
        # filter additionals
        if context_filter.isDisplayTagOnly():
            filtered_addresses = [addr for addr in filtered_addresses if self.last_scan_result[addr].has_tags]
        if context_filter.isDisplayDummyOnly():
            filtered_addresses = [addr for addr in filtered_addresses if self.last_scan_result[addr].has_dummy_name]
        return filtered_addresses

    def getTags(self):
        """
        Get all the tags that have been covered by the last scanning.
        @return (list of str) The tags found, unique and in order of first occurrence.
        """
        collected = []
        for function_ctx in self.last_scan_result.values():
            for call_ctx in function_ctx.call_contexts:
                tag = call_ctx.tag
                if tag not in collected:
                    collected.append(tag)
        return collected

    def getGroups(self):
        """
        Get all the groups that have been covered by tags in the last scanning.
        @return (list of str) The groups found, unique and in order of first occurrence.
        """
        group_of_tag = self._createTagToGroupMapping()
        collected = []
        for function_ctx in self.last_scan_result.values():
            for call_ctx in function_ctx.call_contexts:
                group = group_of_tag[call_ctx.tag]
                if group not in collected:
                    collected.append(group)
        return collected

    def _createTagToGroupMapping(self):
        """
        @return: (dict) tag -> group for all active semantic definitions (later definitions win)
        """
        return {definition["tag"]: definition["group"] for definition in self.semantic_definitions}

    def getTagsForFunctionAddress(self, address):
        """
        Get all tags found for the function containing the queried address.
        @param address: address in the target function
        @type address: int
        @return: (list of str) The unique tags for the function containing the queried address,
                 empty if the function is not part of the last scan result
        """
        function_address = self.getFunctionAddressForAddress(address)
        if function_address not in self.last_scan_result:
            return []
        unique_tags = []
        for call_ctx in self.last_scan_result[function_address].call_contexts:
            if call_ctx.tag not in unique_tags:
                unique_tags.append(call_ctx.tag)
        return unique_tags

    def getFieldCountForFunctionAddress(self, query, address):
        """
        Get the number of occurrences for a certain field for the function containing the queried address.
        @param query: a tuple (type, name), where type is additional, tag, or group and name the field being queried.
        @type query: tuple
        @param address: address in the target function
        @type address: int
        @return: (int) The number of occurrences for this field in the function; 0 if the function
                 is not part of the last scan result
        """
        function_address = self.getFunctionAddressForAddress(address)
        # like the other per-function getters, treat unscanned functions as "nothing found"
        # instead of raising KeyError
        if function_address not in self.last_scan_result:
            return 0
        return self.last_scan_result[function_address].getCountForField(query)

    def getTaggedApisForFunctionAddress(self, address):
        """
        Get all tagged call contexts for the function containing the queried address.
        @param address: address in the target function
        @type address: int
        @return: (list of CallContext data objects) The call contexts with a non-empty tag identified
                 by the scanning of this function; empty if the function was not scanned
        """
        function_address = self.getFunctionAddressForAddress(address)
        if function_address not in self.last_scan_result:
            # return an empty list instead of an implicit None so callers can always iterate
            return []
        all_call_ctx = self.last_scan_result[function_address].call_contexts
        return [call_ctx for call_ctx in all_call_ctx if call_ctx.tag != ""]

    def getAddressTagPairsOrderedByFunction(self):
        """
        Map every identified function to the tags of the calls it contains.
        @return: a dictionary with key/value entries of the following form: (function_address,
                 dict((call_address, tag)))
        """
        pairs_by_function = {}
        for function_address in self.getIdentifiedFunctionAddresses():
            tagged_calls = self.getTaggedApisForFunctionAddress(function_address)
            call_to_tag = pairs_by_function.setdefault(function_address, {})
            for context in tagged_calls:
                call_to_tag[context.address_of_call] = context.tag
        return pairs_by_function

    def getFunctionsToRename(self):
        """
        Determine which functions from the last scan can be renamed. A function qualifies only while IDA
        still flags its name as a dummy name (FF_LABL); its new name is the old one prefixed with every tag
        found in it that is not already part of the name.
        @return: a list of dictionaries, each consisting of three tuples: ("old_function_name", str), \
                 ("new_function_name", str), ("function_address", int)
        """
        rename_candidates = []
        for function_address in self.last_scan_result.keys():
            original_name = self.last_scan_result[function_address].function_name
            flags = self.ida_proxy.GetFlags(function_address)
            if (flags & self.ida_proxy.FF_LABL) <= 0:
                # the function already carries a non-dummy name; leave it alone
                continue
            prefixed_name = original_name
            for tag in sorted(self.getTagsForFunctionAddress(function_address), reverse=True):
                if tag not in prefixed_name:
                    prefixed_name = tag + self.renaming_seperator + prefixed_name
            rename_candidates.append({
                "old_function_name": original_name,
                "new_function_name": prefixed_name,
                "function_address": function_address,
            })
        return rename_candidates

    def renameFunctions(self):
        """
        Apply the names computed by I{getFunctionsToRename}. A function is skipped if its current name no
        longer matches the name recorded during the last scan (i.e. it was renamed in the meantime).
        """
        for candidate in self.getFunctionsToRename():
            address = candidate["function_address"]
            if self.ida_proxy.GetFunctionName(address) != candidate["old_function_name"]:
                continue
            self.ida_proxy.MakeNameEx(address, candidate["new_function_name"], self.ida_proxy.SN_NOWARN)

    def renamePotentialWrapperFunctions(self):
        """
        Rename functions that look like thin wrappers around a single call target.
        contributed by Branko Spasojevic.

        Every function that still carries an IDA dummy name (FF_LABL set) is checked with
        I{_checkWrapperHeuristics}. If exactly one acceptable call with a usable target name was found,
        the function is renamed to "<target>_w<n>", using the first suffix n that IDA accepts.
        """
        # suffixes _w0 .. _w<max_name_suffix> are tried before giving up on a wrapper
        max_name_suffix = 40
        num_wrappers_renamed = 0
        for seg_ea in self.ida_proxy.Segments():
            for func_ea in self.ida_proxy.Functions(self.ida_proxy.SegStart(seg_ea), self.ida_proxy.SegEnd(seg_ea)):
                # only functions with a dummy name are candidates (same flag as in getFunctionsToRename)
                if (self.ida_proxy.GetFlags(func_ea) & self.ida_proxy.FF_LABL) == 0:
                    continue
                nr_calls, w_name = self._checkWrapperHeuristics(func_ea)
                if nr_calls != 1 or not w_name:
                    continue
                renamed = False
                f_name = ""
                for name_suffix in range(max_name_suffix + 1):
                    f_name = "%s_w%d" % (w_name, name_suffix)
                    # MakeNameEx fails if the name is already taken, so try the next suffix
                    renamed = self.ida_proxy.MakeNameEx(func_ea, f_name, \
                        self.ida_proxy.SN_NOCHECK | self.ida_proxy.SN_NOWARN)
                    if renamed:
                        break
                else:
                    # the message previously claimed 50 although only 41 suffixes are tried
                    print("[!] Potentially more than %d wrappers for function %s, " \
                        "please report this IDB ;)" % (max_name_suffix + 1, w_name))
                if renamed:
                    print("[+] Identified and renamed potential wrapper @ [%08x] to [%s]" % \
                        (func_ea, f_name))
                    num_wrappers_renamed += 1
        print("[+] Renamed %d functions with their potentially wrapped name." % num_wrappers_renamed)

    def _checkWrapperHeuristics(self, func_ea):
        """
        Helps renamePotentialWrapperFunctions() to decide whether the function analyzed is a wrapper or not.
        @param func_ea: start address of the function to inspect
        @type func_ea: int
        @return: a tuple (nr_calls, w_name). nr_calls is a score where 1 means "exactly one acceptable call";
                 larger values mark the function as not being a wrapper (0 is returned for functions that
                 are too large). w_name is the name of the call target if it is a library function or
                 carries a non-dummy name, "" otherwise.
        """
        max_wrapper_size = 0x40
        nr_calls = 0
        w_name = ""
        func_end = self.ida_proxy.GetFunctionAttr(func_ea, self.ida_proxy.FUNCATTR_END)
        # Heuristic: wrappers are likely short. The previous check was inverted and rejected exactly the
        # short functions; sizes <= 0 (e.g. a failed attribute lookup) are rejected as well.
        if not 0 < (func_end - func_ea) < max_wrapper_size:
            return (0, "")
        # Heuristic: wrappers shall only have a single reference, ideally to a library function.
        for i_ea in self.ida_proxy.FuncItems(func_ea):
            mnemonic = self.ida_proxy.GetMnem(i_ea)
            # long jumps don't occur in wrappers considered by this code.
            if mnemonic == 'jmp':
                jump_target = self.ida_proxy.GetOperandValue(i_ea, 0)
                # func_end is exclusive, so a jump to func_end already leaves the function
                if jump_target < func_ea or jump_target >= func_end:
                    nr_calls += 2
            if mnemonic == 'call':
                nr_calls += 1
                # only direct calls are accepted (IDA operand types o_mem=2, o_far=6, o_near=7)
                if self.ida_proxy.GetOpType(i_ea, 0) not in (2, 6, 7):
                    nr_calls += 2
                if nr_calls > 1:
                    break
                call_dst = list(self.ida_proxy.CodeRefsFrom(i_ea, 0))
                if not call_dst:
                    continue
                call_dst = call_dst[0]
                # accept the target's name if it is a library function or has a real (non-dummy) name.
                # The check previously looked at func_ea, which the caller already requires to be dummy-named.
                if (self.ida_proxy.GetFunctionFlags(call_dst) & self.ida_proxy.FUNC_LIB) != 0 or \
                        (self.ida_proxy.GetFlags(call_dst) & self.ida_proxy.FF_LABL) == 0:
                    w_name = self.ida_proxy.Name(call_dst)
        return (nr_calls, w_name)


    def getParametersForCallAddress(self, call_address):
        """
        Get the parameters for the given address of a function call.
        @param call_address: address of the target call to inspect
        @type call_address: int
        @return: a list of ParameterContext data objects; empty if the call is not a tagged API call of a
                 scanned function.
        """
        target_function_address = self.ida_proxy.LocByName(self.ida_proxy.GetFunctionName(call_address))
        # guard against a None result for functions that were not part of the last scan
        all_tagged_apis_in_function = self.getTaggedApisForFunctionAddress(target_function_address) or []
        for api in all_tagged_apis_in_function:
            if api.address_of_call == call_address:
                return self._resolveApiCall(api)
        return []

    def _resolveApiCall(self, call_context):
        """
        Match the pushes preceding an API call against the API's signature.
        @param call_context: the call context describing the API call
        @type call_context: a CallContext data object
        @return: a list of ParameterContext data objects.
        """
        signature = self._getApiSignature(call_context.real_called_function_name)
        pushes = self._getPushAddressesBeforeTargetAddress(call_context.address_of_call)
        return self._matchPushAddressesToSignature(pushes, signature)

    def _matchPushAddressesToSignature(self, push_addresses, api_signature):
        """
        Pair the push instructions found before a call with the parameters of the called function.
        Only the pushes closest to the call are considered. If fewer pushes than parameters were found, the
        parameters listed first in I{api_signature} are marked invalid, so that the last parameters are the
        ones that get matched.
        @param push_addresses: addresses of the push instructions before the call, in address order
        @type push_addresses: a list of int
        @param api_signature: information about a function definition with
                              parameter names, types, and so on.
        @type api_signature: a dictionary with the layout as returned by I{_getApiSignature}
        @return: a list of ParameterContext data objects, one per signature parameter.
        """
        # TODO:
        # upgrade this feature with data flow analysis to resolve parameters with higher precision
        parameters = api_signature["parameters"]
        num_params = len(parameters)
        relevant_pushes = push_addresses[-num_params:]
        # negative when pushes are missing: shifts indices so the trailing parameters line up with the pushes
        offset = len(relevant_pushes) - num_params
        matched = []
        for position, param in enumerate(parameters):
            push_index = position + offset
            context = self.ParameterContext()
            context.parameter_type = param["type"]
            context.parameter_name = param["name"]
            if push_index < 0:
                context.valid = False
            else:
                push_address = relevant_pushes[push_index]
                context.push_address = push_address
                context.ida_operand_type = self.ida_proxy.GetOpType(push_address, 0)
                context.ida_operand_value = self.ida_proxy.GetOperandValue(push_address, 0)
                context.value = context.ida_operand_value
            matched.append(context)
        return matched

    def _getApiSignature(self, api_name):
        """
        Get the signature for a function by using IDA's I{GetType()}. The string is then parsed with a Regex and
        returned as a dictionary.
        @param api_name: name of the API / function to get type information for
        @type api_name: str
        @return: a dictionary with key/value entries of the following form: ("api_name", str),
                 ("return_type", str), ("parameters", [dict(("type", str), ("name", str))]).
                 "return_type" is only present if the type string could be parsed; parameters are listed
                 last-to-first. Unnamed parameters get an empty "name".
        """
        api_signature = {"api_name": api_name, "parameters": []}
        api_location = self.ida_proxy.LocByName(api_name)
        type_def = self.ida_proxy.GetType(api_location)
        function_signature_regex = r"(?P<return_type>[\w\s\*]+)\((?P<parameters>[,\.\*\w\s]*)\)"
        # presumably GetType() yields None when IDA has no type for the location; re.match() rejects None
        result = self.re.match(function_signature_regex, type_def) if type_def else None
        if result is not None:
            api_signature["return_type"] = result.group("return_type")
            if len(result.group("parameters")) > 0:
                for parameter in result.group("parameters").split(","):
                    parameter = parameter.strip()
                    type_part, separator, name_part = parameter.rpartition(" ")
                    if separator:
                        type_and_name = {"type": type_part.strip(), "name": name_part.strip()}
                    else:
                        # unnamed parameter such as "int" or "...": the whole token is the type.
                        # (slicing at rfind(" ") == -1 used to split "int" into "in" / "t")
                        type_and_name = {"type": parameter, "name": ""}
                    api_signature["parameters"].append(type_and_name)
        else:
            # format the string inside print(); applying % to print()'s return value fails on Python 3
            print(("[-] SemanticIdentifier._getApiSignature: No API/function signature for \"%s\" @ 0x%x available. "
                   "(non-critical)") % (api_name, api_location))
        # TODO:
        # here should be a check for the calling convention
        # currently, the list is simply reversed to match the order parameters are pushed to the stack
        api_signature["parameters"].reverse()
        return api_signature

    def _getPushAddressesBeforeTargetAddress(self, address):
        """
        Collect the push instructions of the basic block containing I{address}, scanning from the block start
        up to (and stopping at) the first instruction at or after I{address}.
        @param address: address to get the push addresses for.
        @type address: int
        @return: a list of int, in address order
        """
        pushes = []
        flow_chart = self.ida_proxy.FlowChart(self.ida_proxy.get_func(address))
        for basic_block in flow_chart:
            if not (basic_block.startEA <= address < basic_block.endEA):
                continue
            for head in self.ida_proxy.Heads(basic_block.startEA, basic_block.endEA):
                if self.ida_proxy.GetMnem(head) == "push":
                    pushes.append(head)
                if head >= address:
                    break
        return pushes

    def createFunctionGraph(self, func_address):
        """
        Build the call graph reachable from I{func_address}, restricted to functions contained in the
        last scan result.
        @param func_address: address of the root function
        @type func_address: int
        @return: a dictionary {"root": func_address, "nodes": {function_address: calls_from}}; "nodes" is
                 empty if the root function was not scanned.
        """
        scan_result = self.last_scan_result
        graph = {"root": func_address, "nodes": {}}
        if func_address not in scan_result:
            return graph
        nodes = graph["nodes"]
        nodes[func_address] = scan_result[func_address].calls_from
        pending = set(scan_result[func_address].calls_from)
        while pending:
            candidate = pending.pop()
            if candidate in nodes or candidate not in scan_result:
                continue
            nodes[candidate] = scan_result[candidate].calls_from
            pending.update(set(scan_result[candidate].calls_from) - set(nodes))
        return graph

    def createFunctionContextFilter(self):
        """
        Create a function filter, containing only those tags/groups that have been identified within the last scan.
        All tags and groups start out enabled.
        @return: a FunctionContextFilter data object
        """
        context_filter = self.FunctionContextFilter()
        context_filter.tags = sorted((tag,) * 3 for tag in self.getTags())
        context_filter.enabled_tags = context_filter.tags
        context_filter.groups = sorted((group,) * 3 for group in self.getGroups())
        context_filter.enabled_groups = context_filter.groups
        return context_filter

    def getLastScanResult(self):
        """
        Return the result of the most recent I{scanByReferences} run (the stored object itself, not a copy).
        @return: a dictionary with key/value entries of the following form: (function_address, FunctionContext)
        """
        scan_result = self.last_scan_result
        return scan_result

    def printLastScanResult(self):
        """
        nicely print the last scan result (mostly used for debugging)
        One line per function (address, name, tags), followed by one indented line per call context.
        """
        for function_address in self.last_scan_result.keys():
            # format inside print(): "print(...) % args" only worked with the Python 2 print statement
            print("0x%x - %s -> %s" % (function_address, self.ida_proxy.GetFunctionName(function_address),
                                       ", ".join(self.getTagsForFunctionAddress(function_address))))
            for call_ctx in self.last_scan_result[function_address].call_contexts:
                print("    0x%x - %s (%s)" % (call_ctx.address_of_call, call_ctx.called_function_name, call_ctx.tag))
예제 #3
0
class CryptoIdentifier():
    """
    This class contains the logic to perform Crypto identification.
    Two techniques are currently supported:
    1. A heuristic approach that identifies functions and basic blocks
    based on the ratio of arithmetic/logic instructions to all instructions
    2. A signature-based approach, using the signatures defined in PatternManager
    """

    # upper bound used when a "max" threshold is expanded to infinite; no block is assumed to exceed it
    UNBOUNDED_THRESHOLD = 1000000

    def __init__(self):
        self.name = "CryptoIdentifier"
        print("[*] loading CryptoIdentifier")
        self.time = time
        self.re = re
        self.GraphHelper = GraphHelper
        self.CryptoSignatureHit = CryptoSignatureHit
        self.AritlogBasicBlock = AritlogBasicBlock
        self.Segment = Segment
        self.pm = PatternManager()
        self.low_rating_threshold = 0.4
        self.high_rating_threshold = 1.0
        self.low_instruction_threshold = 8
        self.high_instruction_threshold = 100
        # if the threshold is set to this value, it is automatically expanded to infinite.
        self.max_instruction_threshold = 100
        self.low_call_threshold = 0
        self.high_call_threshold = 1
        # if the threshold is set to this value, it is automatically expanded to infinite.
        self.max_call_threshold = 10
        # if at least this fraction of a signature's length has been identified
        # consecutively, the location is marked as a signature hit.
        self.match_filter_factor = 0.5
        self.aritlog_blocks = []
        self.signature_hits = []
        self.ida_proxy = IdaProxy()

    def scan(self):
        """
        Scan the whole IDB with all available techniques.
        """
        self.scanAritlog()
        self.scanCryptoPatterns()

################################################################################
# Aritlog scanning
################################################################################

    def scanAritlog(self):
        """
        scan with the arithmetic/logic heuristic
        @return: a list of AritLogBasicBlock data objects that fulfill the parameters as specified
        """
        print("[*] CryptoIdentifier: Starting aritlog heuristic analysis.")
        self.aritlog_blocks = []
        time_before = self.time.time()
        for function_ea in self.ida_proxy.Functions():
            function_chart = self.ida_proxy.FlowChart(self.ida_proxy.get_func(function_ea))
            calls_in_function = 0
            function_blocks = []
            function_dgraph = {}
            blocks_in_loops = set()
            for current_block in function_chart:
                block = self.AritlogBasicBlock(current_block.startEA, current_block.endEA)
                for instruction in self.ida_proxy.Heads(block.start_ea, block.end_ea):
                    if self.ida_proxy.isCode(self.ida_proxy.GetFlags(instruction)):
                        mnemonic = self.ida_proxy.GetMnem(instruction)
                        # presumably lets the block recognize zeroing idioms like "xor eax, eax" (see is_nonzero)
                        has_identical_operands = self.ida_proxy.GetOperandValue(instruction, 0) == \
                            self.ida_proxy.GetOperandValue(instruction, 1)
                        block.updateInstructionCount(mnemonic, has_identical_operands)
                        if mnemonic == "call":
                            calls_in_function += 1
                function_blocks.append(block)
                # prepare graph for Tarjan's algorithm
                succeeding_blocks = [succ.startEA for succ in current_block.succs()]
                function_dgraph[current_block.startEA] = succeeding_blocks
                # a block that is its own successor forms a trivial loop
                if current_block.startEA in succeeding_blocks:
                    block.is_contained_in_trivial_loop = True
                    blocks_in_loops.add(current_block.startEA)
            # strongly connected components (= loops) with more than one block are non-trivial loops
            graph_helper = self.GraphHelper()
            strongly_connected = graph_helper.calculateStronglyConnectedComponents(function_dgraph)
            for component in strongly_connected:
                if len(component) > 1:
                    blocks_in_loops.update(component)
            for block in function_blocks:
                if block.start_ea in blocks_in_loops:
                    block.is_contained_in_loop = True
                block.num_calls_in_function = calls_in_function
            self.aritlog_blocks.extend(function_blocks)
        print("[*] Heuristics analysis took %3.2f seconds." % (self.time.time() - time_before))

        return self.getAritlogBlocks(self.low_rating_threshold, self.high_rating_threshold,
            self.low_instruction_threshold, self.high_instruction_threshold,
            self.low_call_threshold, self.high_call_threshold,
            False, False, False)

    def _updateThresholds(self, min_rating, max_rating, min_instr, max_instr, min_call, max_call):
        """
        update all six threshold bounds
        @param min_rating: the minimum arit/log ratio a basic block must have
        @type min_rating: float
        @param max_rating: the maximum arit/log ratio a basic block can have
        @type max_rating: float
        @param min_instr: the minimum number of instructions a basic block must have
        @type min_instr: int
        @param max_instr: the maximum number of instructions a basic block can have
        @type max_instr: int
        @param min_call: the minimum number of calls a basic block must have
        @type min_call: int
        @param max_call: the maximum number of calls a basic block can have
        @type max_call: int
        """
        self.low_rating_threshold = max(0.0, min_rating)
        self.high_rating_threshold = min(1.0, max_rating)
        self.low_instruction_threshold = max(0, min_instr)
        if max_instr >= self.max_instruction_threshold:
            # reaching the configured maximum means "no upper bound" on instructions
            self.high_instruction_threshold = self.UNBOUNDED_THRESHOLD
        else:
            self.high_instruction_threshold = max_instr
        self.low_call_threshold = max(0, min_call)
        if max_call >= self.max_call_threshold:
            # reaching the configured maximum means "no upper bound" on calls
            self.high_call_threshold = self.UNBOUNDED_THRESHOLD
        else:
            self.high_call_threshold = max_call

    def getAritlogBlocks(self, min_rating, max_rating, min_instr, max_instr, min_api, max_api, is_nonzero, \
        is_looped, is_trivially_looped):
        """
        get all blocks that are within the limits specified by the heuristic parameters.
        parameters are the same as in function "_updateThresholds" except
        param is_nonzero: defines whether zeroing instructions (like xor eax, eax) shall be counted or not.
        type is_nonzero: boolean
        param is_looped: defines whether only basic blocks in loops shall be selected
        type is_looped: boolean
        param is_trivially_looped: defines whether only basic blocks that are their own successor shall be selected
        type is_trivially_looped: boolean
        @return: a list of AritlogBasicBlock data objects, according to the parameters
        """
        self._updateThresholds(min_rating, max_rating, min_instr, max_instr, min_api, max_api)
        return [block for block in self.aritlog_blocks if
            (self.high_rating_threshold >= block.getAritlogRating(is_nonzero) >= self.low_rating_threshold) and
            (self.high_instruction_threshold >= block.num_instructions >= self.low_instruction_threshold) and
            (self.high_call_threshold >= block.num_calls_in_function >= self.low_call_threshold) and
            (not is_looped or block.is_contained_in_loop) and
            (not is_trivially_looped or block.is_contained_in_trivial_loop)]

    def getUnfilteredBlockCount(self):
        """
        returns the number of basic blocks that have been analyzed.
        @return: (int) number of basic blocks
        """
        return len(self.aritlog_blocks)

################################################################################
# Signature scanning
################################################################################

    def getSegmentData(self):
        """
        returns the raw bytes of the segments as stored by IDA
        Segments whose data cannot be read are reported and skipped.
        @return: a list of Segment data objects.
        """
        segments = []
        for segment_ea in self.ida_proxy.Segments():
            try:
                segment = self.Segment()
                segment.start_ea = segment_ea
                segment.end_ea = self.ida_proxy.SegEnd(segment_ea)
                segment.name = self.ida_proxy.SegName(segment_ea)
                # one join instead of repeated "+=" (quadratic on large segments)
                segment.data = "".join(chr(self.ida_proxy.get_byte(ea))
                                       for ea in helpers.Misc.lrange(segment_ea, segment.end_ea))
                segments.append(segment)
            except Exception:
                # best effort: skip this segment but do not swallow KeyboardInterrupt/SystemExit
                print("[!] Tried to access invalid segment data. An error has occurred while address conversion")
        return segments

    def scanCryptoPatterns(self, pattern_size=32):
        """
        Scan all segments for the fixed and variable crypto signatures of the PatternManager.
        @param pattern_size: token size passed to I{PatternManager.getTokenizedSignatures}
        @type pattern_size: int
        @return: a list of CryptoSignatureHit data objects (also stored in I{signature_hits})
        """
        crypt_results = []
        print("[*] CryptoIdentifier: Starting crypto signature scanning.")
        time_before_matching = self.time.time()
        segments = self.getSegmentData()
        keywords = self.pm.getTokenizedSignatures(pattern_size)
        for keyword, signature_names in keywords.items():
            # compile once per keyword instead of once per (keyword, segment) pair
            keyword_regex = self.re.compile(self.re.escape(keyword))
            for segment in segments:
                crypt_results.extend([self.CryptoSignatureHit(segment.start_ea + match.start(), signature_names, keyword)
                                      for match in keyword_regex.finditer(segment.data)])
        variable_matches = self.scanVariablePatterns()
        crypt_results.extend(variable_matches)
        print("[*] Full matching took %3.2f seconds and resulted in %d hits." % (self.time.time() - time_before_matching, len(crypt_results)))
        self.signature_hits = crypt_results
        return crypt_results

    def scanVariablePatterns(self):
        """
        Search the segments for the variable signatures of the PatternManager using IDA's binary search.
        The scanning code is roughly based on kyprizel's signature scan, see credits above for more information.
        @return: a list of CryptoSignatureHit data objects
        """
        crypt_results = []
        variable_signatures = self.pm.getVariableSignatures()
        for var_sig in variable_signatures.keys():
            pattern = variable_signatures[var_sig]
            # byte count of the space separated pattern, used to continue searching behind a hit
            pattern_length = pattern.count(" ") + 1
            current_seg = self.ida_proxy.FirstSeg()
            seg_end = self.ida_proxy.SegEnd(current_seg)
            while current_seg != self.ida_proxy.BAD_ADDR:
                # NOTE(review): 16 is the radix, 1 presumably SEARCH_DOWN — confirm against IdaProxy
                signature_hit = self.ida_proxy.find_binary(current_seg, seg_end, pattern, 16, 1)
                if signature_hit != self.ida_proxy.BAD_ADDR:
                    crypt_results.append(self.CryptoSignatureHit(signature_hit, [var_sig], pattern))
                    current_seg = signature_hit + pattern_length
                else:
                    current_seg = self.ida_proxy.NextSeg(seg_end)
                    if current_seg != self.ida_proxy.BAD_ADDR:
                        seg_end = self.ida_proxy.SegEnd(current_seg)
        return crypt_results

    def getSignatureLength(self, signature_name):
        """
        returns the length for a signature, identified by its name
        @param signature_name: name for a signature, e.g. "ADLER 32"
        @type signature_name: str
        @return: (int) length of the signature, 0 if no signature has this name.
        """
        for signature, name in self.pm.signatures.items():
            if name == signature_name:
                return len(signature)
        return 0

    def getSignatureHits(self):
        """
        Get all signature hits that have a length of at least match_filter_factor percent
        of the signature they triggered.
        Hits are grouped by signature names.
        @return: a dictionary  with key/value entries of the following form: ("signature name", [CryptoSignatureHit])
        """
        sorted_hits = sorted(self.signature_hits)
        unified_hits = []

        previous_signature_names = []
        for hit in sorted_hits:
            hit_intersection = [element for element in hit.signature_names if element in previous_signature_names]
            if len(hit_intersection) == 0:
                previous_signature_names = hit.signature_names
                unified_hits.append(self.CryptoSignatureHit(hit.start_address, hit.signature_names, \
                    hit.matched_signature))
            else:
                previous_signature_names = hit_intersection
                previous_hit = unified_hits[-1]
                if hit.start_address == previous_hit.start_address + len(previous_hit.matched_signature):
                    # directly adjacent hit for a shared signature: extend the previous hit
                    previous_hit.matched_signature += hit.matched_signature
                    previous_hit.signature_names = hit_intersection
                else:
                    unified_hits.append(self.CryptoSignatureHit(hit.start_address, hit.signature_names, \
                        hit.matched_signature))

        filtered_hits = []
        for hit in unified_hits:
            required_length = max([self.match_filter_factor * self.getSignatureLength(name) for name in hit.signature_names])
            if len(hit.matched_signature) >= required_length:
                hit.code_refs_to = self.getXrefsToAddress(hit.start_address)
                filtered_hits.append(hit)

        grouped_hits = {}
        for hit in filtered_hits:
            for name in hit.signature_names:
                grouped_hits.setdefault(name, []).append(hit)

        return grouped_hits

    def getXrefsToAddress(self, address):
        """
        get all references to a certain address.
        These are no xrefs in IDA sense but references to the crypto signatures.
        If the signature points to an instruction, e.g. if a constant is moved to a register, the return is flagged as
        "True", meaning it is an in-code reference.
        @param address: an arbitrary address
        @type address: int
        @return: a list of tuples (int, boolean)
        """
        xrefs = []
        # look at most 14 bytes back for an instruction that may embed the signature as an operand
        head_to_address = self.ida_proxy.PrevHead(address, address - 14)
        # compare against IDA's BADADDR instead of the 32-bit-only literal 0xFFFFFFFF
        if head_to_address != self.ida_proxy.BAD_ADDR:
            flags = self.ida_proxy.GetFlags(head_to_address)
            if self.ida_proxy.isCode(flags):
                xrefs.append((head_to_address, True))
        for x in self.ida_proxy.XrefsTo(address):
            flags = self.ida_proxy.GetFlags(x.frm)
            if self.ida_proxy.isCode(flags):
                xrefs.append((x.frm, False))
        return xrefs
예제 #4
0
class DocumentationHelper():
    """
    This class handles instruction coloring.
    Colors are read from a semantics configuration file: every semantic definition ("tag") belongs to a semantic
    group, and each group defines a base color (applied to whole basic blocks) and a highlight color (applied to the
    tagged instruction itself). Color codes are ints in 0xBBGGRR layout.
    The class also offers helpers that turn code outside of functions into functions.
    """

    # data layout of color maps
    layout_color_map = {"tag": {"base_color": 0x112233, "highlight_color": 0x445566}}

    def __init__(self, idascope_config):
        """
        @param idascope_config: IDAscope configuration; its I{semantics_file} attribute names the semantics file
            that colors are loaded from
        """
        print("[|] loading DocumentationHelper")
        self.ida_proxy = IdaProxy()
        # default colors are grey / light red / red
        self.default_neutral_color = 0xCCCCCC
        self.default_base_color = 0xB3B3FF
        self.default_highlight_color = 0x3333FF
        # last applied scheme of the "individual/mono/uncolored" cycle
        self.color_state = "unknown"
        self.idascope_config = idascope_config
        self._loadConfig(self.idascope_config.semantics_file)
        return

    @staticmethod
    def _parseColor(color):
        """
        Converts a color code to a number.
        @param color: a hex string as stored in semantics files (e.g. "0x3333FF") or an already numeric color code
        @type color: str or int
        @return: (int) the numeric color code
        """
        try:
            return int(color, 16)
        except TypeError:
            # int() rejects an explicit base for non-strings: the value is already numeric (e.g. a default color)
            return int(color)

    def _loadConfig(self, config_filename):
        """
        Loads a semantic configuration file and generates a color map from the contained information.
        @param config_filename: filename of a semantic configuration file
        @type config_filename: str
        """
        # the context manager closes the file even if reading fails
        with open(config_filename, "r") as config_file:
            config = config_file.read()
        parsed_config = json.loads(config, object_hook=JsonHelper.decode_dict)
        self.default_neutral_color = self._parseColor(parsed_config["default_neutral_color"])
        self.default_base_color = self._parseColor(parsed_config["default_base_color"])
        self.default_highlight_color = self._parseColor(parsed_config["default_highlight_color"])
        self.color_map = self._generateColorMapFromDefinitions(parsed_config)
        return

    def _generateColorMapFromDefinitions(self, config):
        """
        Internal function to generate a color map from a semantic definitions config file.
        @param config: parsed semantics configuration containing "semantic_definitions" and "semantic_groups"
        @type config: dict
        @return: a dictionary of a color map, see I{layout_color_map} for a reference
        """
        color_map = {}
        for definition in config["semantic_definitions"]:
            base_color, highlight_color = self._getColorsForGroup(definition["group"], config)
            # group colors are hex strings, while the fallback defaults are already numbers
            color_map[definition["tag"]] = {"base_color": self._parseColor(base_color),
                                            "highlight_color": self._parseColor(highlight_color)}
        return color_map

    def _getColorsForGroup(self, target_group, config):
        """
        Looks up the colors of a semantic group.
        @param target_group: tag of the semantic group
        @type target_group: str
        @param config: parsed semantics configuration containing a "semantic_groups" list
        @type config: dict
        @return: a tuple (base_color, highlight_color): the values stored in the config, or the numeric default
            colors if the group is not defined
        """
        for group in config["semantic_groups"]:
            if group["tag"] == target_group:
                return (group["base_color"], group["highlight_color"])
        print("[-] Failed to get colors for group \"%s\" - you might want to check your semantics file." % target_group)
        return (self.default_base_color, self.default_highlight_color)

    def uncolorAll(self):
        """
        Uncolors all instructions of all segments by changing their color to white.
        """
        for seg_ea in self.ida_proxy.Segments():
            for function_address in self.ida_proxy.Functions(self.ida_proxy.SegStart(seg_ea), \
                self.ida_proxy.SegEnd(seg_ea)):
                for block in self.ida_proxy.FlowChart(self.ida_proxy.get_func(function_address)):
                    for head in self.ida_proxy.Heads(block.startEA, block.endEA):
                        self.colorInstruction(head, 0xFFFFFF, refresh=False)
        self.ida_proxy.refresh_idaview_anyway()

    def colorInstruction(self, address, color, refresh=True):
        """
        Colors the instruction at an address with the given color code.
        @param address: address of the instruction to color
        @type address: int
        @param color: color-code to set for the instruction
        @type color: int (0xBBGGRR)
        @param refresh: refresh IDA view to ensure the color shows directly, can be omitted for performance.
        @type refresh: boolean
        """
        self.ida_proxy.SetColor(address, self.ida_proxy.CIC_ITEM, color)
        if refresh:
            self.ida_proxy.refresh_idaview_anyway()

    def colorBasicBlock(self, address, color, refresh=True):
        """
        Colors the basic block containing a target address with the given color code.
        @param address: address an instruction in the basic block to color
        @type address: int
        @param color: color-code to set for the instruction
        @type color: int (0xBBGGRR)
        @param refresh: refresh IDA view to ensure the color shows directly, can be omitted for performance.
        @type refresh: boolean
        """
        function_chart = self.ida_proxy.FlowChart(self.ida_proxy.get_func(address))
        for block in function_chart:
            if block.startEA <= address < block.endEA:
                for head in self.ida_proxy.Heads(block.startEA, block.endEA):
                    self.colorInstruction(head, color, refresh)

    def getNextColorScheme(self):
        """
        get the next color scheme in the three-cycle "individual/mono/uncolored", where individual is semantic coloring
        @return: next state ("individual" for any unknown current state)
        """
        successors = {"individual": "mono", "mono": "uncolored"}
        return successors.get(self.color_state, "individual")

    def selectHighlightColor(self, tag):
        """
        automatically chooses the highlight color for a tag based on the current color scheme
        @param tag: semantic tag of the instruction
        @type tag: str
        @return: (int) a color code
        """
        next_scheme = self.getNextColorScheme()
        if next_scheme == "uncolored":
            return 0xFFFFFF
        if next_scheme == "mono":
            return self.default_highlight_color
        return self.color_map[tag]["highlight_color"]

    def selectBaseColor(self, tagged_addresses_in_block):
        """
        automatically chooses the base color for a block based on the current color scheme
        @param tagged_addresses_in_block: all tagged addresses in a basic block for which the color shall be chosen
        @type tagged_addresses_in_block: a list of tuples (int, str) containing pairs of instruction addresses and tags
        @return: (int) a color code; in individual mode the neutral color is used when the block mixes base colors
        """
        next_scheme = self.getNextColorScheme()
        if next_scheme == "uncolored":
            return 0xFFFFFF
        if next_scheme == "mono":
            return self.default_base_color
        colors_in_block = set(self.color_map[item[1]]["base_color"] for item in tagged_addresses_in_block)
        if len(colors_in_block) == 1:
            return colors_in_block.pop()
        return self.default_neutral_color

    def colorize(self, scan_result):
        """
        perform coloring on the IDB, based on a scan performed by SemanticIdentifier
        @param scan_result: result of a scan as performed by SemanticIdentifier
        @type scan_result: a dictionary with key/value entries of the following form: (address, [FunctionContext])
        """
        for function_address, function_ctx in scan_result.items():
            tagged_addresses_in_function = function_ctx.getAllTaggedAddresses()
            function_chart = self.ida_proxy.FlowChart(self.ida_proxy.get_func(function_address))
            for basic_block in function_chart:
                # a plain range comparison; "addr in xrange(...)" is a linear scan in Python 2
                tagged_addresses_in_block = [(addr, tag) for addr, tag in tagged_addresses_in_function.items()
                                             if basic_block.startEA <= addr < basic_block.endEA]
                if tagged_addresses_in_block:
                    base_color = self.selectBaseColor(tagged_addresses_in_block)
                    self.colorBasicBlock(basic_block.startEA, base_color, refresh=False)
                    for address, tag in tagged_addresses_in_block:
                        self.colorInstruction(address, self.selectHighlightColor(tag), refresh=False)
        self.color_state = self.getNextColorScheme()
        self.ida_proxy.refresh_idaview_anyway()

    def getNextNonFuncInstruction(self, addr):
        """
        Finds the next address after addr that does not belong to a function and is flagged as code.
        @param addr: address to start searching from
        @type addr: int
        @return: (int) the found address or BAD_ADDR if there is none
        """
        next_instruction = addr
        while next_instruction != self.ida_proxy.BAD_ADDR:
            next_instruction = self.ida_proxy.find_not_func(next_instruction, self.ida_proxy.SEARCH_DOWN)
            if next_instruction == self.ida_proxy.BAD_ADDR:
                break
            if self.ida_proxy.isCode(self.ida_proxy.GetFlags(next_instruction)):
                return next_instruction
        return self.ida_proxy.BAD_ADDR

    def convertNonFunctionCode(self):
        """
        Converts code that does not belong to any function into functions: first at recognized prologues,
        then at every remaining piece of non-function code.
        """
        self.convertAnyProloguesToFunctions()
        # do a second run to define the rest
        next_instruction = self.ida_proxy.minEA()
        while next_instruction != self.ida_proxy.BAD_ADDR:
            next_instruction = self.getNextNonFuncInstruction(next_instruction)
            # search exhausted: neither report nor create a function at BAD_ADDR
            if next_instruction == self.ida_proxy.BAD_ADDR:
                break
            print("[+] Fixed undefined code to function @ [%08x]" % \
                (next_instruction))
            self.ida_proxy.MakeFunction(next_instruction)
        return

    def convertAnyProloguesToFunctions(self):
        """
        Creates functions at "push ebp; mov ebp, esp" prologues, both in undefined data and in non-function code.
        """
        self.convertDataWithPrologueToCode()
        self.convertNonFunctionCodeWithPrologues()

    def convertNonFunctionCodeWithPrologues(self):
        """
        Creates a function at every piece of non-function code that starts with "push ebp; mov ebp, esp".
        """
        next_instruction = self.ida_proxy.minEA()
        while next_instruction != self.ida_proxy.BAD_ADDR:
            next_instruction = self.getNextNonFuncInstruction(next_instruction)
            # no further non-function code: do not inspect BAD_ADDR
            if next_instruction == self.ida_proxy.BAD_ADDR:
                break
            # operand type 1 with value 5 / 4 encodes the register operands ebp / esp checked below
            if self.ida_proxy.GetMnem(next_instruction).startswith("push") and \
                    self.ida_proxy.GetOpType(next_instruction, 0) == 1 and \
                    self.ida_proxy.GetOperandValue(next_instruction, 0) == 5:
                instruction_after_push = self.getNextNonFuncInstruction(next_instruction)
                if instruction_after_push != self.ida_proxy.BAD_ADDR and \
                        self.ida_proxy.GetMnem(instruction_after_push).startswith("mov") and \
                        self.ida_proxy.GetOpType(instruction_after_push, 0) == 1 and \
                        self.ida_proxy.GetOperandValue(instruction_after_push, 0) == 5 and \
                        self.ida_proxy.GetOpType(instruction_after_push, 1) == 1 and \
                        self.ida_proxy.GetOperandValue(instruction_after_push, 1) == 4:
                    print("[+] Fixed undefined code with function prologue (push ebp; mov ebp, esp) to function " \
                        + "@ [%08x]" % (next_instruction))
                    self.ida_proxy.MakeFunction(next_instruction)

    def convertDataWithPrologueToCode(self):
        """
        Searches every segment for the byte pattern "55 8B EC" (push ebp; mov ebp, esp) and creates a function at
        each hit that is not yet code.
        """
        # iterate segment starts directly: advancing with NextSeg(segment_end) would skip a segment that begins
        # exactly where the current one ends
        for segment_start in self.ida_proxy.Segments():
            segment_end = self.ida_proxy.SegEnd(segment_start)
            search_start = segment_start
            while search_start < segment_end:
                signature_hit = self.ida_proxy.find_binary(search_start, segment_end, "55 8B EC", 16, 1)
                if signature_hit == self.ida_proxy.BAD_ADDR:
                    break
                flags = self.ida_proxy.GetFlags(signature_hit)
                if not self.ida_proxy.isCode(flags):
                    self.ida_proxy.MakeFunction(signature_hit)
                    print("[+] Fixed undefined data with potential function prologue (push ebp; mov ebp, esp) to function " \
                            + "@ [%08x]" % (signature_hit))
                # resume the search behind the matched 3-byte prologue
                search_start = signature_hit + 3 + 1
예제 #5
0
class DocumentationHelper():
    """
    This class handles instruction coloring.
    Colors are read from a semantics configuration file: every semantic definition ("tag") belongs to a semantic
    group, and each group defines a base color (applied to whole basic blocks) and a highlight color (applied to the
    tagged instruction itself). Color codes are ints in 0xBBGGRR layout.
    The class also offers helpers that turn code outside of functions into functions.
    """

    # data layout of color maps
    layout_color_map = {
        "tag": {
            "base_color": 0x112233,
            "highlight_color": 0x445566
        }
    }

    def __init__(self, idascope_config):
        """
        @param idascope_config: IDAscope configuration; its I{semantics_file} attribute names the semantics file
            that colors are loaded from
        """
        print("[|] loading DocumentationHelper")
        self.ida_proxy = IdaProxy()
        # default colors are grey / light red / red
        self.default_neutral_color = 0xCCCCCC
        self.default_base_color = 0xB3B3FF
        self.default_highlight_color = 0x3333FF
        # last applied scheme of the "individual/mono/uncolored" cycle
        self.color_state = "unknown"
        self.idascope_config = idascope_config
        self._loadConfig(self.idascope_config.semantics_file)
        return

    @staticmethod
    def _parseColor(color):
        """
        Converts a color code to a number.
        @param color: a hex string as stored in semantics files (e.g. "0x3333FF") or an already numeric color code
        @type color: str or int
        @return: (int) the numeric color code
        """
        try:
            return int(color, 16)
        except TypeError:
            # int() rejects an explicit base for non-strings: the value is already numeric (e.g. a default color)
            return int(color)

    def _loadConfig(self, config_filename):
        """
        Loads a semantic configuration file and generates a color map from the contained information.
        @param config_filename: filename of a semantic configuration file
        @type config_filename: str
        """
        # the context manager closes the file even if reading fails
        with open(config_filename, "r") as config_file:
            config = config_file.read()
        parsed_config = json.loads(config, object_hook=JsonHelper.decode_dict)
        self.default_neutral_color = self._parseColor(
            parsed_config["default_neutral_color"])
        self.default_base_color = self._parseColor(
            parsed_config["default_base_color"])
        self.default_highlight_color = self._parseColor(
            parsed_config["default_highlight_color"])
        self.color_map = self._generateColorMapFromDefinitions(parsed_config)
        return

    def _generateColorMapFromDefinitions(self, config):
        """
        Internal function to generate a color map from a semantic definitions config file.
        @param config: parsed semantics configuration containing "semantic_definitions" and "semantic_groups"
        @type config: dict
        @return: a dictionary of a color map, see I{layout_color_map} for a reference
        """
        color_map = {}
        for definition in config["semantic_definitions"]:
            base_color, highlight_color = self._getColorsForGroup(
                definition["group"], config)
            # group colors are hex strings, while the fallback defaults are already numbers
            color_map[definition["tag"]] = {
                "base_color": self._parseColor(base_color),
                "highlight_color": self._parseColor(highlight_color)
            }
        return color_map

    def _getColorsForGroup(self, target_group, config):
        """
        Looks up the colors of a semantic group.
        @param target_group: tag of the semantic group
        @type target_group: str
        @param config: parsed semantics configuration containing a "semantic_groups" list
        @type config: dict
        @return: a tuple (base_color, highlight_color): the values stored in the config, or the numeric default
            colors if the group is not defined
        """
        for group in config["semantic_groups"]:
            if group["tag"] == target_group:
                return (group["base_color"], group["highlight_color"])
        print("[-] Failed to get colors for group \"%s\" - you might want to check your semantics file." % target_group)
        return (self.default_base_color, self.default_highlight_color)

    def uncolorAll(self):
        """
        Uncolors all instructions of all segments by changing their color to white.
        """
        for seg_ea in self.ida_proxy.Segments():
            for function_address in self.ida_proxy.Functions(self.ida_proxy.SegStart(seg_ea), \
                self.ida_proxy.SegEnd(seg_ea)):
                for block in self.ida_proxy.FlowChart(
                        self.ida_proxy.get_func(function_address)):
                    for head in self.ida_proxy.Heads(block.startEA,
                                                     block.endEA):
                        self.colorInstruction(head, 0xFFFFFF, refresh=False)
        self.ida_proxy.refresh_idaview_anyway()

    def colorInstruction(self, address, color, refresh=True):
        """
        Colors the instruction at an address with the given color code.
        @param address: address of the instruction to color
        @type address: int
        @param color: color-code to set for the instruction
        @type color: int (0xBBGGRR)
        @param refresh: refresh IDA view to ensure the color shows directly, can be omitted for performance.
        @type refresh: boolean
        """
        self.ida_proxy.SetColor(address, self.ida_proxy.CIC_ITEM, color)
        if refresh:
            self.ida_proxy.refresh_idaview_anyway()

    def colorBasicBlock(self, address, color, refresh=True):
        """
        Colors the basic block containing a target address with the given color code.
        @param address: address an instruction in the basic block to color
        @type address: int
        @param color: color-code to set for the instruction
        @type color: int (0xBBGGRR)
        @param refresh: refresh IDA view to ensure the color shows directly, can be omitted for performance.
        @type refresh: boolean
        """
        function_chart = self.ida_proxy.FlowChart(
            self.ida_proxy.get_func(address))
        for block in function_chart:
            if block.startEA <= address < block.endEA:
                for head in self.ida_proxy.Heads(block.startEA, block.endEA):
                    self.colorInstruction(head, color, refresh)

    def getNextColorScheme(self):
        """
        get the next color scheme in the three-cycle "individual/mono/uncolored", where individual is semantic coloring
        @return: next state ("individual" for any unknown current state)
        """
        successors = {"individual": "mono", "mono": "uncolored"}
        return successors.get(self.color_state, "individual")

    def selectHighlightColor(self, tag):
        """
        automatically chooses the highlight color for a tag based on the current color scheme
        @param tag: semantic tag of the instruction
        @type tag: str
        @return: (int) a color code
        """
        next_scheme = self.getNextColorScheme()
        if next_scheme == "uncolored":
            return 0xFFFFFF
        if next_scheme == "mono":
            return self.default_highlight_color
        return self.color_map[tag]["highlight_color"]

    def selectBaseColor(self, tagged_addresses_in_block):
        """
        automatically chooses the base color for a block based on the current color scheme
        @param tagged_addresses_in_block: all tagged addresses in a basic block for which the color shall be chosen
        @type tagged_addresses_in_block: a list of tuples (int, str) containing pairs of instruction addresses and tags
        @return: (int) a color code; in individual mode the neutral color is used when the block mixes base colors
        """
        next_scheme = self.getNextColorScheme()
        if next_scheme == "uncolored":
            return 0xFFFFFF
        if next_scheme == "mono":
            return self.default_base_color
        colors_in_block = set(self.color_map[item[1]]["base_color"]
                              for item in tagged_addresses_in_block)
        if len(colors_in_block) == 1:
            return colors_in_block.pop()
        return self.default_neutral_color

    def colorize(self, scan_result):
        """
        perform coloring on the IDB, based on a scan performed by SemanticIdentifier
        @param scan_result: result of a scan as performed by SemanticIdentifier
        @type scan_result: a dictionary with key/value entries of the following form: (address, [FunctionContext])
        """
        for function_address, function_ctx in scan_result.items():
            tagged_addresses_in_function = function_ctx.getAllTaggedAddresses()
            function_chart = self.ida_proxy.FlowChart(
                self.ida_proxy.get_func(function_address))
            for basic_block in function_chart:
                # a plain range comparison; "addr in xrange(...)" is a linear scan in Python 2
                tagged_addresses_in_block = [
                    (addr, tag) for addr, tag in tagged_addresses_in_function.items()
                    if basic_block.startEA <= addr < basic_block.endEA
                ]
                if tagged_addresses_in_block:
                    base_color = self.selectBaseColor(
                        tagged_addresses_in_block)
                    self.colorBasicBlock(basic_block.startEA,
                                         base_color,
                                         refresh=False)
                    for address, tag in tagged_addresses_in_block:
                        self.colorInstruction(address,
                                              self.selectHighlightColor(tag),
                                              refresh=False)
        self.color_state = self.getNextColorScheme()
        self.ida_proxy.refresh_idaview_anyway()

    def getNextNonFuncInstruction(self, addr):
        """
        Finds the next address after addr that does not belong to a function and is flagged as code.
        @param addr: address to start searching from
        @type addr: int
        @return: (int) the found address or BAD_ADDR if there is none
        """
        next_instruction = addr
        while next_instruction != self.ida_proxy.BAD_ADDR:
            next_instruction = self.ida_proxy.find_not_func(
                next_instruction, self.ida_proxy.SEARCH_DOWN)
            if next_instruction == self.ida_proxy.BAD_ADDR:
                break
            if self.ida_proxy.isCode(self.ida_proxy.GetFlags(next_instruction)):
                return next_instruction
        return self.ida_proxy.BAD_ADDR

    def convertNonFunctionCode(self):
        """
        Converts code that does not belong to any function into functions: first at recognized prologues,
        then at every remaining piece of non-function code.
        """
        self.convertAnyProloguesToFunctions()
        # do a second run to define the rest
        next_instruction = self.ida_proxy.minEA()
        while next_instruction != self.ida_proxy.BAD_ADDR:
            next_instruction = self.getNextNonFuncInstruction(next_instruction)
            # search exhausted: neither report nor create a function at BAD_ADDR
            if next_instruction == self.ida_proxy.BAD_ADDR:
                break
            print("[+] Fixed undefined code to function @ [%08x]" % \
                (next_instruction))
            self.ida_proxy.MakeFunction(next_instruction)
        return

    def convertAnyProloguesToFunctions(self):
        """
        Creates functions at "push ebp; mov ebp, esp" prologues, both in undefined data and in non-function code.
        """
        self.convertDataWithPrologueToCode()
        self.convertNonFunctionCodeWithPrologues()

    def convertNonFunctionCodeWithPrologues(self):
        """
        Creates a function at every piece of non-function code that starts with "push ebp; mov ebp, esp".
        """
        next_instruction = self.ida_proxy.minEA()
        while next_instruction != self.ida_proxy.BAD_ADDR:
            next_instruction = self.getNextNonFuncInstruction(next_instruction)
            # no further non-function code: do not inspect BAD_ADDR
            if next_instruction == self.ida_proxy.BAD_ADDR:
                break
            # operand type 1 with value 5 / 4 encodes the register operands ebp / esp checked below
            if self.ida_proxy.GetMnem(next_instruction).startswith("push") and \
                    self.ida_proxy.GetOpType(next_instruction, 0) == 1 and \
                    self.ida_proxy.GetOperandValue(next_instruction, 0) == 5:
                instruction_after_push = self.getNextNonFuncInstruction(
                    next_instruction)
                if instruction_after_push != self.ida_proxy.BAD_ADDR and \
                        self.ida_proxy.GetMnem(instruction_after_push).startswith("mov") and \
                        self.ida_proxy.GetOpType(instruction_after_push, 0) == 1 and \
                        self.ida_proxy.GetOperandValue(instruction_after_push, 0) == 5 and \
                        self.ida_proxy.GetOpType(instruction_after_push, 1) == 1 and \
                        self.ida_proxy.GetOperandValue(instruction_after_push, 1) == 4:
                    print("[+] Fixed undefined code with function prologue (push ebp; mov ebp, esp) to function " \
                        + "@ [%08x]" % (next_instruction))
                    self.ida_proxy.MakeFunction(next_instruction)

    def convertDataWithPrologueToCode(self):
        """
        Searches every segment for the byte pattern "55 8B EC" (push ebp; mov ebp, esp) and creates a function at
        each hit that is not yet code.
        """
        # iterate segment starts directly: advancing with NextSeg(segment_end) would skip a segment that begins
        # exactly where the current one ends
        for segment_start in self.ida_proxy.Segments():
            segment_end = self.ida_proxy.SegEnd(segment_start)
            search_start = segment_start
            while search_start < segment_end:
                signature_hit = self.ida_proxy.find_binary(search_start, segment_end,
                                                           "55 8B EC", 16, 1)
                if signature_hit == self.ida_proxy.BAD_ADDR:
                    break
                flags = self.ida_proxy.GetFlags(signature_hit)
                if not self.ida_proxy.isCode(flags):
                    self.ida_proxy.MakeFunction(signature_hit)
                    print("[+] Fixed undefined data with potential function prologue (push ebp; mov ebp, esp) to function " \
                            + "@ [%08x]" % (signature_hit))
                # resume the search behind the matched 3-byte prologue
                search_start = signature_hit + 3 + 1
예제 #6
0
class SemanticIdentifier():
    """
    A module to analyze and explore an IDB for semantics. For a set of API names, references to these
    are identified and used for creating context and allowing tagging of them.
    """
    def __init__(self, idascope_config):
        """
        Prepares the analysis state, collects the real names of all imported APIs and loads the semantics files
        referenced by the given configuration.
        @param idascope_config: IDAscope configuration (used for semantics_folder and inspection_default_semantics)
        """
        print("[|] loading SemanticIdentifier")
        # module references kept on the instance
        self.os = os
        self.re = re
        self.time = time
        self.ida_proxy = IdaProxy()
        # context classes kept on the instance
        self.FunctionContext, self.FunctionContextFilter = FunctionContext, FunctionContextFilter
        self.CallContext, self.ParameterContext = CallContext, ParameterContext
        # semantics state, filled by _getRealApiNames() and _loadSemantics() below
        self.semantics, self.active_semantics = {}, {}
        self.renaming_seperator = "_"
        self.semantic_groups, self.semantic_definitions = [], []
        self.real_api_names = {}
        self.last_scan_result = {}
        self.idascope_config = idascope_config
        self._getRealApiNames()
        self._loadSemantics(idascope_config)

    def _cbEnumImports(self, addr, name, ordinal):
        if name:
            self.real_api_names[name] = self.ida_proxy.Name(addr)
        return True

    def _getRealApiNames(self):
        num_imports = self.ida_proxy.get_import_module_qty()
        for i in xrange(0, num_imports):
            self.ida_proxy.enum_import_names(i, self._cbEnumImports)

    def lookupRealApiName(self, api_name):
        if api_name in self.real_api_names:
            return self.real_api_names[api_name]
        else:
            return api_name

    def lookupDisplayApiName(self, real_api_name):
        """ returns the key by given value of self.real_api_names (basically inverted dictionary)
        """
        name = real_api_name
        for display_name in self.real_api_names:
            if real_api_name == self.real_api_names[display_name] \
                    and display_name in self.real_api_names[display_name]:
                name = display_name
        return name

    def _loadSemantics(self, config):
        """
        Loads a semantic configuration file and collects all definitions from it.
        @param config_filename: filename of a semantic configuration file
        @type config_filename: str
        """
        for filename in [
                fn for fn in self.os.listdir(config.semantics_folder)
                if fn.endswith(".json")
        ]:
            loaded_file = self._loadSemanticsFile(config.semantics_folder +
                                                  self.os.sep + filename)
            self.semantics[loaded_file["name"]] = loaded_file
        if config.inspection_default_semantics in self.semantics:
            self._setSemantics(config.inspection_default_semantics)
        elif len(self.semantics) > 0:
            self._setSemantics(sorted(self.semantics.keys())[0])
        else:
            self._setSemantics("")
        return

    def _loadSemanticsFile(self, semantics_filename):
        """
        Reads and parses a single semantics definition file.
        @param semantics_filename: path of a semantics definition (JSON) file
        @type semantics_filename: str
        @return: (dict) the parsed file content
        """
        # the context manager closes the handle right after reading, even if reading fails
        with open(semantics_filename, "r") as semantics_file:
            semantics = semantics_file.read()
        return json.loads(semantics, object_hook=JsonHelper.decode_dict)

    def _setSemantics(self, semantics_entry):
        semantics_content = {}
        if semantics_entry in self.semantics:
            semantics_content = self.semantics[semantics_entry]
            self.renaming_seperator = semantics_content["renaming_seperator"]
            self.semantic_groups = semantics_content["semantic_groups"]
            self.semantic_definitions = semantics_content[
                "semantic_definitions"]
            self.active_semantics = semantics_content
        else:
            self.renaming_seperator = "_"
            self.semantic_groups = []
            self.semantic_definitions = []
            self.active_semantics = {"name": "none"}
        self.scanByReferences()

    def getSemanticsNames(self):
        return sorted(self.semantics.keys())

    def getActiveSemanticsName(self):
        return self.active_semantics["name"]

    def calculateNumberOfBasicBlocksForFunctionAddress(self, function_address):
        """
        Calculates the number of basic blocks for a given function by walking its FlowChart.
        @param function_address: function address to calculate the block count for
        @type function_address: int
        """
        number_of_blocks = 0
        try:
            func_chart = self.ida_proxy.FlowChart(
                self.ida_proxy.get_func(function_address))
            for block in func_chart:
                number_of_blocks += 1
        except:
            pass
        return number_of_blocks

    def getNumberOfBasicBlocksForFunctionAddress(self, address):
        """
        returns the number of basic blocks for the function containing the queried address,
        based on the value stored in the last scan result.

        If the number of basic blocks for this function has never been calculated, zero is returned.
        @param function_address: function address to get the block count for
        @type function_address: int
        @return: (int) The number of blocks in th e function
        """
        number_of_blocks = 0
        function_address = self.getFunctionAddressForAddress(address)
        if function_address in self.last_scan_result.keys():
            number_of_blocks = self.last_scan_result[
                function_address].number_of_basic_blocks
        return number_of_blocks

    def scan(self):
        """
        Scan the whole IDB with all available techniques.
        """
        self.scanByReferences()
        self.scanDeep()

    def scanByReferences(self):
        """
        Scan by references to API names, based on the definitions loaded from the config file.
        This is highly efficient because we only touch places in the IDB that actually have references
        to our API names of interest.
        """
        print(
            "  [/] SemanticIdentifier: Starting (fast) scan by references of function semantics."
        )
        time_before = self.time.time()
        self.last_scan_result = {}
        for semantic_tag in self.semantic_definitions:
            for api_name in semantic_tag["api_names"]:
                real_api_name = self.lookupRealApiName(api_name)
                api_address = self.ida_proxy.LocByName(real_api_name)
                for ref in self._getAllRefsTo(api_address):
                    function_ctx = self._getFunctionContext(ref)
                    function_ctx.has_tags = True
                    call_ctx = self.CallContext()
                    call_ctx.called_function_name = api_name
                    call_ctx.real_called_function_name = real_api_name
                    call_ctx.address_of_call = ref
                    call_ctx.called_address = api_address
                    call_ctx.tag = semantic_tag["tag"]
                    call_ctx.group = semantic_tag["group"]
                    call_ctx.parameter_contexts = self._resolveApiCall(
                        call_ctx)
                    function_ctx.call_contexts.append(call_ctx)
        print("  [\\] Analysis took %3.2f seconds." %
              (self.time.time() - time_before))

    def _getAllRefsTo(self, addr):
        code_ref_addrs = [ref for ref in self.ida_proxy.CodeRefsTo(addr, 0)]
        data_ref_addrs = [ref for ref in self.ida_proxy.DataRefsTo(addr)]
        return iter(set(code_ref_addrs).union(set(data_ref_addrs)))

    def _getNumRefsTo(self, addr):
        return sum([1 for ref in self._getAllRefsTo(addr)])

    def _getAllRefsFrom(self, addr, code_only=False):
        code_ref_addrs = [ref for ref in self.ida_proxy.CodeRefsFrom(addr, 0)]
        data_ref_addrs = []
        if code_only:
            # only consider data references that lead to a call near/far (likely imports)
            data_ref_addrs = [ref for ref in self.ida_proxy.DataRefsFrom(addr) if \
                self.ida_proxy.GetFlags(ref) & (self.ida_proxy.FL_CN | self.ida_proxy.FL_CF)]
        else:
            data_ref_addrs = [ref for ref in self.ida_proxy.DataRefsFrom(addr)]
        return iter(set(code_ref_addrs).union(set(data_ref_addrs)))

    def _getFunctionContext(self, addr):
        """
        Create or return an existing FunctionContext for the given address in the current scan result.
        @param func_addr: address to create a FunctionContext for
        @type func_addr: int
        @return: (FunctionContext) A reference to the corresponding function context
        """
        function_ctx = None
        function_address = self.ida_proxy.LocByName(
            self.ida_proxy.GetFunctionName(addr))
        if function_address not in self.last_scan_result.keys():
            function_ctx = self.FunctionContext()
            function_ctx.function_address = function_address
            function_ctx.function_name = self.ida_proxy.GetFunctionName(
                function_address)
            function_ctx.has_dummy_name = (self.ida_proxy.GetFlags(function_address) & \
                self.ida_proxy.FF_LABL) > 0
            self.last_scan_result[function_ctx.function_address] = function_ctx
        else:
            function_ctx = self.last_scan_result[function_address]
        return function_ctx

    def scanDeep(self):
        """
        Perform a full enumeration of all instructions of all functions and record per-function
        statistics in the corresponding FunctionContext (see I{_getFunctionContext}):
        number of basic blocks, number of instructions (heads), outgoing references and call
        targets, and the number of references pointing to the function.
        Note: number_of_xrefs_from is the total count over all instructions, so a target
        referenced by several instructions is counted several times.
        """
        print(
            "  [/] SemanticIdentifier: Starting deep scan of function semantics."
        )
        time_before = self.time.time()
        for function_ea in self.ida_proxy.Functions():
            function_chart = self.ida_proxy.FlowChart(
                self.ida_proxy.get_func(function_ea))
            num_blocks = 0
            num_instructions = 0
            xrefs_from = []
            calls_from = []
            function_ctx = self._getFunctionContext(function_ea)
            for block in function_chart:
                num_blocks += 1
                for instruction in self.ida_proxy.Heads(
                        block.startEA, block.endEA):
                    num_instructions += 1
                    if not self.ida_proxy.isCode(
                            self.ida_proxy.GetFlags(instruction)):
                        continue
                    refs = list(self._getAllRefsFrom(instruction))
                    # the mnemonic does not depend on the reference, so query it once per instruction
                    # (and only when there is something to record)
                    if refs and self.ida_proxy.GetMnem(instruction) == "call":
                        calls_from.extend(refs)
                    xrefs_from.extend(refs)
            function_ctx.calls_from.update(calls_from)
            function_ctx.number_of_xrefs_to = self._getNumRefsTo(function_ea)
            function_ctx.xrefs_from.update(xrefs_from)
            function_ctx.number_of_xrefs_from = len(xrefs_from)
            function_ctx.number_of_basic_blocks = num_blocks
            function_ctx.number_of_instructions = num_instructions
        print("  [\\] Analysis took %3.2f seconds." %
              (self.time.time() - time_before))

    def getFunctionAddressForAddress(self, address):
        """
        Map an arbitrary address to the start address of the function that contains it,
        by looking up the containing function's name and resolving that name back to an address.
        @param address: address to check the function address for
        @type address: int
        @return: (int) The start address of the function containing this address
        """
        containing_name = self.ida_proxy.GetFunctionName(address)
        return self.ida_proxy.LocByName(containing_name)

    def calculateNumberOfFunctions(self):
        """
        Count the functions IDA reports within the bounds of every segment.
        @return: (int) the number of functions found.
        """
        proxy = self.ida_proxy
        return sum(
            1
            for segment in proxy.Segments()
            for _ in proxy.Functions(proxy.SegStart(segment), proxy.SegEnd(segment)))

    def calculateNumberOfTaggedFunctions(self):
        """
        Count the functions of the last scan that pass a filter built from all identified
        tags and groups (see I{createFunctionContextFilter}).
        @return: (int) the number of functions found.
        """
        tagged_filter = self.createFunctionContextFilter()
        tagged_addresses = self.getFunctionAddresses(tagged_filter)
        return len(tagged_addresses)

    def getFunctionAddresses(self, context_filter):
        """
        Get all function addresses covered by the last scan that pass the given filter.
        The primary selection uses the first enabled mode of I{display_all}, I{display_tags},
        I{display_groups} (checked in that order; none enabled selects nothing). The result is
        then restricted further if I{isDisplayTagOnly()} / I{isDisplayDummyOnly()} are set.
        @param context_filter: filter describing which functions to report; its enabled_tags /
                               enabled_groups entries are tuples whose first element is the name
        @type context_filter: FunctionContextFilter
        @return: (list of int) The addresses of covered functions.
        """
        # copy into a list so callers never receive a live dict view (Python 3 keys())
        all_addresses = list(self.last_scan_result.keys())
        filtered_addresses = []
        if context_filter.display_all:
            filtered_addresses = all_addresses
        elif context_filter.display_tags:
            enabled_tags = set(tag[0] for tag in context_filter.enabled_tags)
            filtered_addresses = [
                address for address in all_addresses
                if enabled_tags.intersection(self.last_scan_result[address].getTags())
            ]
        elif context_filter.display_groups:
            enabled_groups = set(group[0] for group in context_filter.enabled_groups)
            filtered_addresses = [
                address for address in all_addresses
                if enabled_groups.intersection(self.last_scan_result[address].getGroups())
            ]
        # filter additionals
        if context_filter.isDisplayTagOnly():
            filtered_addresses = [
                addr for addr in filtered_addresses
                if self.last_scan_result[addr].has_tags
            ]
        if context_filter.isDisplayDummyOnly():
            filtered_addresses = [
                addr for addr in filtered_addresses
                if self.last_scan_result[addr].has_dummy_name
            ]
        return filtered_addresses

    def getTags(self):
        """
        Collect the tags of all call contexts recorded by the last scan.
        @return (list of str) The distinct tags, in order of first appearance.
        """
        seen_tags = []
        for function_ctx in self.last_scan_result.values():
            for call_ctx in function_ctx.call_contexts:
                if call_ctx.tag in seen_tags:
                    continue
                seen_tags.append(call_ctx.tag)
        return seen_tags

    def getGroups(self):
        """
        Collect the groups of all tags found in the call contexts of the last scan.
        The tag-to-group relation comes from I{_createTagToGroupMapping}.
        @return (list of str) The distinct groups, in order of first appearance.
        """
        group_of_tag = self._createTagToGroupMapping()
        seen_groups = []
        for function_ctx in self.last_scan_result.values():
            for call_ctx in function_ctx.call_contexts:
                group = group_of_tag[call_ctx.tag]
                if group in seen_groups:
                    continue
                seen_groups.append(group)
        return seen_groups

    def _createTagToGroupMapping(self):
        mapping = {}
        for definition in self.semantic_definitions:
            mapping[definition["tag"]] = definition["group"]
        return mapping

    def getTagsForFunctionAddress(self, address):
        """
        Get all tags found for the function containing the queried address.
        @param address: address in the target function
        @type address: int
        @return: (list of str) The distinct tags in order of first appearance, or an empty
                 list if the containing function is not part of the last scan result
        """
        owner = self.getFunctionAddressForAddress(address)
        if owner not in self.last_scan_result:
            return []
        unique_tags = []
        for call_ctx in self.last_scan_result[owner].call_contexts:
            if call_ctx.tag in unique_tags:
                continue
            unique_tags.append(call_ctx.tag)
        return unique_tags

    def getFieldCountForFunctionAddress(self, query, address):
        """
        Get the number of occurrences for a certain field for the function containing the queried address.
        @param query: a tuple (type, name), where type is additional, tag, or group and name the field being queried.
        @type query: tuple
        @param address: address in the target function
        @type address: int
        @return: (int) The number of occurrences for this field in the function; 0 if the containing
                 function is not part of the last scan result.
        """
        function_address = self.getFunctionAddressForAddress(address)
        if function_address not in self.last_scan_result:
            # consistent with the other per-function getters: an unscanned function has no occurrences
            return 0
        return self.last_scan_result[function_address].getCountForField(query)

    def getTaggedApisForFunctionAddress(self, address):
        """
        Get all tagged call contexts for the function containing the queried address.
        @param address: address in the target function
        @type address: int
        @return: (list of CallContext data objects) The call contexts with a non-empty tag identified
                 by the scanning of this function; an empty list if the function was not scanned.
        """
        function_address = self.getFunctionAddressForAddress(address)
        if function_address not in self.last_scan_result:
            # callers iterate over the result, so report "no tagged calls" instead of None
            return []
        all_call_ctx = self.last_scan_result[function_address].call_contexts
        return [call_ctx for call_ctx in all_call_ctx if call_ctx.tag != ""]

    def getAddressTagPairsOrderedByFunction(self):
        """
        Get all tagged call contexts for all functions returned by I{getIdentifiedFunctionAddresses}.
        @return: a dictionary with key/value entries of the following form: (function_address,
                 dict((call_address, tag))); functions without tagged calls map to an empty dict.
        """
        # NOTE(review): getIdentifiedFunctionAddresses is not defined in this part of the file;
        # presumably it is defined elsewhere in the class — confirm.
        functions_and_tags = {}
        for function in self.getIdentifiedFunctionAddresses():
            # guard against a missing (None) result for functions that were not scanned
            call_contexts = self.getTaggedApisForFunctionAddress(function) or []
            call_to_tag = functions_and_tags.setdefault(function, {})
            for call_ctx in call_contexts:
                call_to_tag[call_ctx.address_of_call] = call_ctx.tag
        return functions_and_tags

    def getFunctionsToRename(self):
        """
        Compute rename proposals from the last scan result. Only functions whose flags still carry
        FF_LABL (IDA dummy name such as I{sub_[0-9A-F]+}) are considered. Each of the function's tags
        that does not already occur in the name is prepended, joined by I{self.renaming_seperator}.
        @return: a list of dictionaries, each consisting of three entries: ("old_function_name", str), \
                 ("new_function_name", str), ("function_address", int)
        """
        proposals = []
        for address, function_ctx in self.last_scan_result.items():
            current_name = function_ctx.function_name
            still_dummy = self.ida_proxy.GetFlags(address) & self.ida_proxy.FF_LABL > 0
            if not still_dummy:
                continue
            proposed_name = current_name
            # reverse-sorted prepending leaves the tags in ascending order in the final name
            for tag in sorted(self.getTagsForFunctionAddress(address), reverse=True):
                if tag in proposed_name:
                    continue
                proposed_name = tag + self.renaming_seperator + proposed_name
            proposals.append({
                "old_function_name": current_name,
                "new_function_name": proposed_name,
                "function_address": address,
            })
        return proposals

    def renameFunctions(self):
        """
        Apply the proposals of I{getFunctionsToRename}. A proposal is skipped when the function's
        current name no longer equals the name recorded in the proposal.
        """
        proxy = self.ida_proxy
        for proposal in self.getFunctionsToRename():
            target = proposal["function_address"]
            if proposal["old_function_name"] != proxy.GetFunctionName(target):
                continue
            proxy.MakeNameEx(target, proposal["new_function_name"], proxy.SN_NOWARN)

    def renamePotentialWrapperFunctions(self):
        """
        Rename functions that look like wrappers around a single call to the wrapped function's
        name plus a "_w<N>" suffix. Only functions whose flags carry FF_LABL (IDA dummy name) are
        considered; the wrapper decision is delegated to I{_checkWrapperHeuristics} and a function
        qualifies when it reports exactly one call and a non-empty target name. Suffixes
        0, 1, ..., max_name_suffix are tried until MakeNameEx succeeds.
        contributed by Branko Spasojevic.
        """
        max_name_suffix = 40
        num_wrappers_renamed = 0
        for seg_ea in self.ida_proxy.Segments():
            for func_ea in self.ida_proxy.Functions(
                    self.ida_proxy.SegStart(seg_ea),
                    self.ida_proxy.SegEnd(seg_ea)):
                if (self.ida_proxy.GetFlags(func_ea) & self.ida_proxy.FF_LABL) == 0:
                    continue
                nr_calls, w_name = self._checkWrapperHeuristics(func_ea)
                if nr_calls != 1 or len(w_name) == 0:
                    continue
                rval = False
                name_suffix = 0
                f_name = ""
                # retry with increasing suffixes until MakeNameEx accepts the name
                while not rval:
                    if name_suffix > max_name_suffix:
                        print("[!] Potentially more than %d wrappers for function %s, "
                              "please report this IDB ;)" % (max_name_suffix, w_name))
                        break
                    f_name = w_name + '_w' + str(name_suffix)
                    name_suffix += 1
                    rval = self.ida_proxy.MakeNameEx(func_ea, f_name, \
                        self.ida_proxy.SN_NOCHECK | self.ida_proxy.SN_NOWARN)
                if rval:
                    print("[+] Identified and renamed potential wrapper @ [%08x] to [%s]" % \
                        (func_ea, f_name))
                    num_wrappers_renamed += 1
        print("[+] Renamed %d functions with their potentially wrapped name." %
              num_wrappers_renamed)

    def _checkWrapperHeuristics(self, func_ea):
        """
        Helps renamePotentialWrapperFunctions() to decide whether the function analyzed is a wrapper or not.
        """
        nr_calls = 0
        w_name = ""
        # Heuristic: wrappers are likely short
        func_end = self.ida_proxy.GetFunctionAttr(func_ea,
                                                  self.ida_proxy.FUNCATTR_END)
        if (func_end - func_ea) > 0 and (func_end - func_ea) < 0x40:
            return (0, "")
        # Heuristic: wrappers shall only have a single reference, ideally to a library function.
        for i_ea in self.ida_proxy.FuncItems(func_ea):
            # long jumps don't occur in wrappers considered by this code.
            if self.ida_proxy.GetMnem(i_ea) == 'jmp' \
                and (func_ea > self.ida_proxy.GetOperandValue(i_ea,0) \
                    or func_end < self.ida_proxy.GetOperandValue(i_ea,0)):
                nr_calls += 2
            # checks if call is not memory reference
            if self.ida_proxy.GetMnem(i_ea) == 'call':
                nr_calls += 1
                if self.ida_proxy.GetOpType(i_ea,0) != 2 \
                    and self.ida_proxy.GetOpType(i_ea,0) != 6 \
                        and self.ida_proxy.GetOpType(i_ea,0) != 7:
                    nr_calls += 2
                if nr_calls > 1:
                    break
                call_dst = list(self.ida_proxy.CodeRefsFrom(i_ea, 0))
                if len(call_dst) == 0:
                    continue
                call_dst = call_dst[0]
                if (self.ida_proxy.GetFunctionFlags(call_dst) & self.ida_proxy.FUNC_LIB) != 0 or \
                    (self.ida_proxy.GetFlags(func_ea) & self.ida_proxy.FF_LABL) == 0:
                    w_name = self.ida_proxy.Name(call_dst)
        return (nr_calls, w_name)

    def getParametersForCallAddress(self, call_address):
        """
        Get the parameters for the given address of a function call.
        @param call_address: address of the target call to inspect
        @type call_address: int
        @return: a list of ParameterContext data objects; an empty list if the call is not a
                 tagged API call of a scanned function.
        """
        target_function_address = self.ida_proxy.LocByName(
            self.ida_proxy.GetFunctionName(call_address))
        # tolerate a None result for functions that are not part of the last scan
        all_tagged_apis_in_function = self.getTaggedApisForFunctionAddress(
            target_function_address) or []
        for api in all_tagged_apis_in_function:
            if api.address_of_call == call_address:
                return self._resolveApiCall(api)
        return []

    def _resolveApiCall(self, call_context):
        """
        Resolve the parameters for an API calls based on a call context for this API call.
        @param call_context: the call context to get the parameter information for
        @type call_context: a CallContext data object
        @return: a list of ParameterContext data objects.
        """
        resolved_api_parameters = []
        api_signature = self._getApiSignature(
            call_context.real_called_function_name)
        push_addresses = self._getPushAddressesBeforeTargetAddress(
            call_context.address_of_call)
        resolved_api_parameters = self._matchPushAddressesToSignature(
            push_addresses, api_signature)
        return resolved_api_parameters

    def _matchPushAddressesToSignature(self, push_addresses, api_signature):
        """
        Combine the results of I{_getPushAddressesBeforeTargetAddress} and I{_getApiSignature} in order to
        produce a list of ParameterContext data objects.
        @param push_addresses: the identified push addresses before a function call that shall be matched to a function
                               signature
        @type push_addresses: a list of int
        @param api_signature: information about a function definition with
                              parameter names, types, and so on.
        @type api_signature: a dictionary with the layout as returned by I{_getApiSignature}
        @return: a list of ParameterContext data objects.
        """
        matched_parameters = []
        # TODO:
        # upgrade this feature with data flow analysis to resolve parameters with higher precision
        api_num_params = len(api_signature["parameters"])
        push_addresses = push_addresses[-api_num_params:]
        # TODO:
        # There might be the case where we identify less pushed parameters than required by the function
        # signature. Thus we calculate a "parameter discrepancy" that we use to adjust our enumeration index
        # so that the last n parameters get matched correctly. This is a temporary fix and might be solved later on.
        parameter_discrepancy = len(push_addresses) - api_num_params
        for index, param in enumerate(api_signature["parameters"],
                                      start=parameter_discrepancy):
            param_ctx = self.ParameterContext()
            param_ctx.parameter_type = param["type"]
            param_ctx.parameter_name = param["name"]
            if (parameter_discrepancy != 0) and (index < 0):
                param_ctx.valid = False
            else:
                param_ctx.push_address = push_addresses[index]
                param_ctx.ida_operand_type = self.ida_proxy.GetOpType(
                    push_addresses[index], 0)
                param_ctx.ida_operand_value = self.ida_proxy.GetOperandValue(
                    push_addresses[index], 0)
                param_ctx.value = param_ctx.ida_operand_value
            matched_parameters.append(param_ctx)
        return matched_parameters

    def _getApiSignature(self, api_name):
        """
        Get the signature for a function by using IDA's I{GetType()}. The string is then parsed with a Regex and
        returned as a dictionary.
        @param api_name: name of the API / function to get type information for
        @type api_name: str
        @return: a dictionary with key/value entries of the following form: ("return_type", str),
                 ("parameters", [dict(("type", str), ("name", str))])
        """
        api_signature = {"api_name": api_name, "parameters": []}
        api_location = self.ida_proxy.LocByName(api_name)
        type_def = self.ida_proxy.GetType(api_location)
        function_signature_regex = r"(?P<return_type>[\w\s\*]+)\((?P<parameters>[,\.\*\w\s]*)\)"
        result = self.re.match(function_signature_regex, type_def)
        if result is not None:
            api_signature["return_type"] = result.group("return_type")
            if len(result.group("parameters")) > 0:
                for parameter in result.group("parameters").split(","):
                    type_and_name = {}
                    type_and_name["type"] = parameter[:parameter.
                                                      rfind(" ")].strip()
                    type_and_name["name"] = parameter[parameter.
                                                      rfind(" "):].strip()
                    api_signature["parameters"].append(type_and_name)
        else:
            print ("[-] SemanticIdentifier._getApiSignature: No API/function signature for \"%s\" @ 0x%x available. " \
            + "(non-critical)") % (api_name, api_location)
        # TODO:
        # here should be a check for the calling convention
        # currently, list list is simply reversed to match the order parameters are pushed to the stack
        api_signature["parameters"].reverse()
        return api_signature

    def _getPushAddressesBeforeTargetAddress(self, address):
        """
        Get the addresses of all push instructions in the basic block preceding the given address.
        @param address: address to get the push addresses for.
        @type address: int
        @return: a list of int
        """
        push_addresses = []
        function_chart = self.ida_proxy.FlowChart(
            self.ida_proxy.get_func(address))
        for block in function_chart:
            if block.startEA <= address < block.endEA:
                for instruction_addr in self.ida_proxy.Heads(
                        block.startEA, block.endEA):
                    if self.ida_proxy.GetMnem(instruction_addr) == "push":
                        push_addresses.append(instruction_addr)
                    if instruction_addr >= address:
                        break
        return push_addresses

    def createFunctionGraph(self, func_address):
        """
        Build the call graph reachable from a function, using the calls_from sets of the last scan.
        Call targets that are not part of the last scan result are not expanded.
        @param func_address: start address of the root function
        @type func_address: int
        @return: a dictionary {"root": func_address, "nodes": {address: calls_from}}; "nodes" is
                 empty if the root function is not part of the last scan result.
        """
        graph = {"root": func_address, "nodes": {}}
        scan_result = self.last_scan_result
        if func_address not in scan_result:
            return graph
        nodes = graph["nodes"]
        nodes[func_address] = scan_result[func_address].calls_from
        pending = set(scan_result[func_address].calls_from)
        while pending:
            candidate = pending.pop()
            if candidate in nodes or candidate not in scan_result:
                continue
            nodes[candidate] = scan_result[candidate].calls_from
            pending.update(set(scan_result[candidate].calls_from) - set(nodes))
        return graph

    def createFunctionContextFilter(self):
        """
        Create a function filter, containing only those tags/groups that have been identified within the last scan.
        Every tag/group is stored as a (name, name, name) tuple; the sorted lists are used both as the
        available and the enabled entries (the same list object for each).
        @return: (FunctionContextFilter) the prepared filter
        """
        context_filter = self.FunctionContextFilter()
        tag_entries = sorted((tag, tag, tag) for tag in self.getTags())
        context_filter.tags = tag_entries
        context_filter.enabled_tags = tag_entries
        group_entries = sorted((group, group, group) for group in self.getGroups())
        context_filter.groups = group_entries
        context_filter.enabled_groups = group_entries
        return context_filter

    def getLastScanResult(self):
        """
        Return the stored result of the last scan (the object itself, not a copy).
        @return: a dictionary with key/value entries of the following form: (function_address, FunctionContext)
        """
        scan_result = self.last_scan_result
        return scan_result

    def printLastScanResult(self):
        """
        Nicely print the last scan result (mostly used for debugging): one line per function with
        its tags, followed by one indented line per call context (call address, callee, tag).
        """
        for function_address in self.last_scan_result.keys():
            # the formatting must happen inside print(): applying % to print()'s return value
            # (None) raises a TypeError on Python 3
            print("0x%x - %s -> " % (function_address, self.ida_proxy.GetFunctionName(function_address))
                  + ", ".join(self.getTagsForFunctionAddress(function_address)))
            for call_ctx in self.last_scan_result[
                    function_address].call_contexts:
                print("    0x%x - %s (%s)" % (call_ctx.address_of_call,
                                              call_ctx.called_function_name,
                                              call_ctx.tag))