Example #1
def datafilter(schema: SortedDict, datapool: DataFrame,
               **kwargs) -> List[DataFrame]:
    """Splits a DataFrame based on the value of applicable columns.

    Each DataFrame object in the returned list will have a single value for
    those columns contained in the schema.

    Args:
        schema: Column names to use for filtering (taken from the dict's
            values; the keys only fix the split order).
        datapool: Data to filter.
        **kwargs: Internal use only; ``filtered`` carries the accumulated
            result between recursive calls.

    Returns:
        The filtered DataFrame objects. An empty list is returned if no schema
        values could be found.
    """
    result: List[DataFrame] = (kwargs["filtered"].copy()
                               if "filtered" in kwargs else [])

    this_schema: SortedDict = schema.copy()
    try:
        _, header = this_schema.popitem(index=0)
    except KeyError:
        result.append(datapool)
        return result

    if header not in datapool.columns:
        return result

    for value in datapool.get(header).drop_duplicates().values:
        new_pool: DataFrame = datapool.loc[datapool.get(header) == value]
        result = datafilter(this_schema, new_pool, filtered=result)
    return result
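
A minimal usage sketch (hypothetical data; assumes pandas and sortedcontainers are installed). Note that the schema's values name the columns to split on, while its keys only fix the split order:

from pandas import DataFrame
from sortedcontainers import SortedDict

pool = DataFrame({"region": ["eu", "eu", "us"], "value": [1, 2, 3]})
schema = SortedDict({0: "region"})  # split by the "region" column
parts = datafilter(schema, pool)
# parts holds two DataFrames: the "eu" rows and the "us" rows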
Example #2
def _get_expected_paths(path: str,
                        schema: SortedDict,
                        subset: DataFrame,
                        filename: str,
                        path_list=None) -> SortedList:
    # prevent mutable default parameter
    if path_list is None:
        path_list = SortedList()

    this_schema = schema.copy()
    header = None
    try:
        _, header = this_schema.popitem(last=False)
    except KeyError:
        path_list.add(os.path.join(path, filename))
        return path_list

    if header not in subset.columns:
        return path_list

    for value in subset.get(header).drop_duplicates().values:
        new_subset = subset.loc[subset.get(header) == value]
        value = value.lower().replace(' ', '_')
        if value[-1] == '.':
            value = value[:-1]
        path_list = _get_expected_paths(os.path.join(path, value),
                                        this_schema,
                                        new_subset,
                                        filename,
                                        path_list=path_list)
    return path_list
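
A hypothetical run (the column name, values, and POSIX path separators are assumptions; the snippet also relies on the v1-style SortedDict.popitem(last=False) signature). One path is produced per unique value, lower-cased with spaces replaced by underscores:

from pandas import DataFrame
from sortedcontainers import SortedDict

frame = DataFrame({"Category": ["Fruit Stand", "Dairy"]})
schema = SortedDict({0: "Category"})
paths = _get_expected_paths("data", schema, frame, "items.csv")
# SortedList(['data/dairy/items.csv', 'data/fruit_stand/items.csv'])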
Example #3
import string

from sortedcontainers import SortedDict


def test_copy():
    mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)]
    temp = SortedDict(mapping)
    dup = temp.copy()
    assert len(temp) == 26
    dup.clear()
    # clearing the copy must leave the original dict untouched
    assert len(temp) == 26
Example #5
class ResolverContext:
    def __init__(self):
        self.__factory = SortedDict()

    def __str__(self) -> str:
        return str(self.__factory)

    def register(self, typename_prefix, resolver):
        self.__factory[typename_prefix] = resolver

    def run(self, obj, **kw):
        typename = obj.meta.typename
        prefix, resolver = find_most_precise_match(typename, self.__factory)
        vineyard_client = kw.pop('__vineyard_client', None)
        if prefix:
            resolver_func_sig = inspect.getfullargspec(resolver)
            if resolver_func_sig.varkw is not None:
                value = resolver(obj, resolver=self, **kw)
            else:
                # don't pass the `**kw`.
                if 'resolver' in resolver_func_sig.args:
                    value = resolver(obj, resolver=self)
                else:
                    value = resolver(obj)
            if value is None:
                # if the obj has been resolved by pybind types, and there's no proper
                # resolver, it shouldn't be an error
                if type(obj) is not Object:
                    return obj

                raise RuntimeError(
                    'Unable to construct the object using resolver: '
                    'typename is %s, resolver is %s' % (obj.meta.typename, resolver)
                )

            # associate a reference to the base C++ object
            try:
                setattr(value, '__vineyard_ref', obj)
                setattr(value, '__vineyard_client', vineyard_client)

                # register methods
                get_current_drivers().resolve(value, obj.typename)
            except AttributeError:
                pass

            return value
        # keep it as it is
        return obj

    def __call__(self, obj, **kw):
        return self.run(obj, **kw)

    def extend(self, resolvers=None):
        resolver = ResolverContext()
        resolver.__factory = self.__factory.copy()
        if resolvers:
            resolver.__factory.update(resolvers)
        return resolver
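
A hypothetical registration sketch (the typename, metadata layout, and resolver body below are made up; find_most_precise_match and Object come from the surrounding vineyard module):

context = ResolverContext()

def int_resolver(obj):  # simplest form: takes neither `resolver` nor **kw
    return int(obj.meta['value'])

context.register('vineyard::Scalar<int>', int_resolver)
# value = context(some_vineyard_object)  # __call__ dispatches to run()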
Example #6
def _datafilter(schema: SortedDict,
                datapool: DataFrame,
                filtered=None) -> list:
    """The `filtered` parameter should only be used internally."""
    # prevent mutable default parameter
    if filtered is None:
        filtered = []

    this_schema = schema.copy()
    header = None
    try:
        _, header = this_schema.popitem(last=False)
    except KeyError:
        filtered.append(datapool)
        return filtered

    if header not in datapool.columns:
        return filtered

    for value in datapool.get(header).drop_duplicates().values:
        new_pool = datapool.loc[datapool.get(header) == value]
        filtered = _datafilter(this_schema, new_pool, filtered=filtered)
    return filtered
Example #7
class PatchManager(KnowledgeBasePlugin):
    """
    A placeholder-style implementation for a binary patch manager. This class should be significantly changed in the
    future, when all data about loaded binary objects is loaded into the angr knowledge base from CLE. As of now, it
    only stores byte-level replacements. Other angr components may choose to use or not use the information provided
    by this manager. In other words, it is not transparent.

    Patches should not overlap, but it is the user's responsibility to check for and avoid overlapping patches.
    """
    def __init__(self, kb):
        super().__init__()

        self._patches = SortedDict()
        self._kb = kb

    def add_patch(self, addr, new_bytes):
        self._patches[addr] = Patch(addr, new_bytes)

    def remove_patch(self, addr):
        if addr in self._patches:
            del self._patches[addr]

    def patch_addrs(self):
        return self._patches.keys()

    def get_patch(self, addr):
        """
        Get patch at the given address.

        :param int addr:    The address of the patch.
        :return:            The patch if there is one starting at the address, or None if there isn't any.
        :rtype:             Patch or None
        """
        return self._patches.get(addr, None)

    def get_all_patches(self, addr, size):
        """
        Retrieve all patches that cover a region specified by [addr, addr+size).

        :param int addr:    The address of the beginning of the region.
        :param int size:    Size of the region.
        :return:            A list of patches.
        :rtype:             list
        """
        patches = []
        for patch_addr in self._patches.irange(maximum=addr + size - 1,
                                               reverse=True):
            p = self._patches[patch_addr]
            if self.overlap(p.addr, p.addr + len(p), addr, addr + size):
                patches.append(p)
            else:
                break
        return patches[::-1]

    def keys(self):
        return self._patches.keys()

    def items(self):
        return self._patches.items()

    def values(self):
        return self._patches.values()

    def copy(self):
        o = PatchManager(self._kb)
        o._patches = self._patches.copy()
        return o

    @staticmethod
    def overlap(a0, a1, b0, b1):
        return a0 <= b0 < a1 or a0 <= b1 < a1 or b0 <= a0 < b1
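
The overlap helper is self-contained and treats each patch as the half-open range [addr, addr + len); a quick sketch of what it reports (addresses are made up):

PatchManager.overlap(0x10, 0x14, 0x12, 0x18)  # True: the ranges share [0x12, 0x14)
PatchManager.overlap(0x10, 0x14, 0x14, 0x18)  # False: the ranges only touch at 0x14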
Example #8
class TreePage(BasePage):
    """
    Page object, implemented with a sorted dict. Who knows what's underneath!
    """

    def __init__(self, *args, **kwargs):
        storage = kwargs.pop("storage", None)
        super(TreePage, self).__init__(*args, **kwargs)
        self._storage = SortedDict() if storage is None else storage

    def keys(self):
        if len(self._storage) == 0:
            return set()
        else:
            return set.union(*(set(range(*self._resolve_range(mo))) for mo in self._storage.values()))

    def replace_mo(self, state, old_mo, new_mo):
        start, end = self._resolve_range(old_mo)
        for key in self._storage.irange(start, end-1):
            val = self._storage[key]
            if val is old_mo:
                #assert new_mo.includes(a)
                self._storage[key] = new_mo

    def store_overwrite(self, state, new_mo, start, end):
        # iterate over each item we might overwrite
        # track our mutations separately since we're in the process of iterating
        deletes = []
        updates = { start: new_mo }

        for key in self._storage.irange(maximum=end-1, reverse=True):
            old_mo = self._storage[key]

            # make sure we aren't overwriting all of an item that overlaps the end boundary
            if end < self._page_addr + self._page_size and end not in updates and old_mo.includes(end):
                updates[end] = old_mo

            # we can't set a minimum on the range because we need to do the above for
            # the first object before start too
            if key < start:
                break

            # delete any key that falls within the range
            deletes.append(key)

        #assert all(m.includes(i) for i,m in updates.items())

        # perform mutations
        for key in deletes:
            del self._storage[key]

        self._storage.update(updates)

    def store_underwrite(self, state, new_mo, start, end):
        # track the point that we need to write up to
        last_missing = end - 1
        # track also updates since we can't update while iterating
        updates = {}

        for key in self._storage.irange(maximum=end-1, reverse=True):
            mo = self._storage[key]

            # if the mo stops
            if mo.base <= last_missing and not mo.includes(last_missing):
                updates[max(mo.last_addr+1, start)] = new_mo
            last_missing = mo.base - 1

            # we can't set a minimum on the range because we need to do the above for
            # the first object before start too
            if last_missing < start:
                break

        # if there are no memory objects <= start, we won't have filled start yet
        if last_missing >= start:
            updates[start] = new_mo

        #assert all(m.includes(i) for i,m in updates.items())

        self._storage.update(updates)

    def load_mo(self, state, page_idx):
        """
        Loads a memory object from memory.

        :param page_idx: the index into the page
        :returns: the memory object covering the index, or None if there is none
        """

        try:
            key = next(self._storage.irange(maximum=page_idx, reverse=True))
        except StopIteration:
            return None
        else:
            return self._storage[key]

    def load_slice(self, state, start, end):
        """
        Return the memory objects overlapping with the provided slice.

        :param start: the start address
        :param end: the end address (non-inclusive)
        :returns: tuples of (starting_addr, memory_object)
        """
        keys = list(self._storage.irange(start, end-1))
        if not keys or keys[0] != start:
            try:
                key = next(self._storage.irange(maximum=start, reverse=True))
            except StopIteration:
                pass
            else:
                if self._storage[key].includes(start):
                    keys.insert(0, key)
        return [(key, self._storage[key]) for key in keys]

    def _copy_args(self):
        return { 'storage': self._storage.copy() }
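
The recurring lookup trick in this class is the floor search: find the greatest key at or below an address via irange(maximum=..., reverse=True). A self-contained sketch:

from sortedcontainers import SortedDict

storage = SortedDict({0: "a", 8: "b", 16: "c"})
try:
    floor_key = next(storage.irange(maximum=10, reverse=True))
except StopIteration:
    floor_key = None
# floor_key == 8: the object starting at 8 is the one covering address 10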
Example #9
def check_split_by_blast(transcript, cds_boundaries):

    """
    This method verifies if a transcript with multiple ORFs has support by BLAST to
    NOT split it into its different components.

    The minimal overlap between ORF and HSP is defined inside the JSON at the key
        ["chimera_split"]["blast_params"]["minimal_hsp_overlap"]
    basically, we consider a HSP a hit only if the overlap is over a certain threshold
    and the HSP evalue under a certain threshold.

    The split by CDS can be executed in three different ways - PERMISSIVE, LENIENT, STRINGENT:

    - PERMISSIVE: split if two CDSs do not have hits in common,
    even when one or both do not have a hit at all.
    - STRINGENT: split only if two CDSs have hits and none
    of those is in common between them.
    - LENIENT: split if *both* lack hits, OR *both* have hits and none
    of those is in common.

    :param transcript: the transcript instance
    :type transcript: Mikado.loci_objects.transcript.Transcript
    :param cds_boundaries:
    :return: cds_boundaries
    :rtype: dict
    """

    # Establish the minimum overlap between an ORF and a BLAST hit to consider it
    # to establish belongingness

    minimal_overlap = transcript.json_conf[
        "pick"]["chimera_split"]["blast_params"]["minimal_hsp_overlap"]

    cds_hit_dict = SortedDict.fromkeys(cds_boundaries.keys())
    for key in cds_hit_dict:
        cds_hit_dict[key] = collections.defaultdict(list)

    # BUG, this is a hacky fix
    if not hasattr(transcript, "blast_hits"):
        transcript.logger.warning(
            "BLAST hits store lost for %s! Creating a mock one to avoid a crash",
            transcript.id)
        transcript.blast_hits = []

    transcript.logger.debug("%s has %d possible hits", transcript.id, len(transcript.blast_hits))

    # Determine for each CDS which are the hits available
    min_eval = transcript.json_conf["pick"]['chimera_split']['blast_params']['hsp_evalue']
    for hit in transcript.blast_hits:
        for hsp in iter(_hsp for _hsp in hit["hsps"] if
                        _hsp["hsp_evalue"] <= min_eval):
            for cds_run in cds_boundaries:
                # If I have a valid hit b/w the CDS region and the hit,
                # add the name to the set
                overlap_threshold = minimal_overlap * (cds_run[1] + 1 - cds_run[0])
                overl = overlap(cds_run, (hsp['query_hsp_start'], hsp['query_hsp_end']))

                if overl >= overlap_threshold:
                    cds_hit_dict[cds_run][(hit["target"], hit["target_length"])].append(hsp)
                    transcript.logger.debug(
                        "Overlap %s passed for %s between %s CDS and %s HSP (threshold %s)",
                        overl,
                        transcript.id,
                        cds_run,
                        (hsp['query_hsp_start'], hsp['query_hsp_end']),
                        overlap_threshold)
                else:
                    transcript.logger.debug(
                        "Overlap %s rejected for %s between %s CDS and %s HSP (threshold %s)",
                        overl,
                        transcript.id,
                        cds_run, (hsp['query_hsp_start'], hsp['query_hsp_end']),
                        overlap_threshold)

    transcript.logger.debug("Final cds_hit_dict for %s: %s", transcript.id, cds_hit_dict)

    final_boundaries = SortedDict()
    for boundary in __get_boundaries_from_blast(transcript, cds_boundaries, cds_hit_dict):
        if len(boundary) == 1:
            assert len(boundary[0]) == 2
            boundary = boundary[0]
            final_boundaries[boundary] = cds_boundaries[boundary]
        else:
            nboun = (boundary[0][0], boundary[-1][1])
            final_boundaries[nboun] = []
            for boun in boundary:
                final_boundaries[nboun].extend(cds_boundaries[boun])
    transcript.logger.debug("Final boundaries for %s: %s",
                            transcript.id, final_boundaries)

    cds_boundaries = final_boundaries.copy()
    return cds_boundaries
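
The per-ORF threshold scales with ORF length. A worked instance (numbers are made up): with minimal_hsp_overlap = 0.5, an ORF spanning (101, 300) covers 200 bases, so an HSP must overlap it by at least 100 bases to count:

minimal_overlap = 0.5
cds_run = (101, 300)
overlap_threshold = minimal_overlap * (cds_run[1] + 1 - cds_run[0])  # 0.5 * 200 = 100.0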
Example #10
class RegionMap(object):
    """
    Mostly used in SimAbstractMemory, RegionMap stores a series of mappings between concrete memory address ranges and
    memory regions, like stack frames and heap regions.
    """
    def __init__(self, is_stack):
        """
        Constructor

        :param is_stack:    Whether this is a region map for stack frames or not. Different strategies apply for stack
                            regions.
        """
        self.is_stack = is_stack

        # A sorted dict, which maps absolute addresses to region descriptors
        self._address_to_region_id = SortedDict()
        # A dict, which maps region IDs to region descriptors
        self._region_id_to_address = {}

    #
    # Properties
    #

    def __repr__(self):
        return "RegionMap<%s>" % ("S" if self.is_stack else "H")

    @property
    def is_empty(self):
        return len(self._address_to_region_id) == 0

    @property
    def stack_base(self):
        if not self.is_stack:
            raise SimRegionMapError(
                'Calling "stack_base" on a non-stack region map.')

        return next(self._address_to_region_id.irange(reverse=True))

    @property
    def region_ids(self):
        return self._region_id_to_address.keys()

    #
    # Public methods
    #

    @SimStatePlugin.memo
    def copy(self, memo):  # pylint: disable=unused-argument
        r = RegionMap(is_stack=self.is_stack)

        # A shallow copy should be enough, since we never modify any RegionDescriptor object in-place
        r._address_to_region_id = self._address_to_region_id.copy()
        r._region_id_to_address = self._region_id_to_address.copy()

        return r

    def map(self, absolute_address, region_id, related_function_address=None):
        """
        Add a mapping between an absolute address and a region ID. If this is a stack region map, all stack regions
        beyond (lower than) this newly added regions will be discarded.

        :param absolute_address:            An absolute memory address.
        :param region_id:                   ID of the memory region.
        :param related_function_address:    A related function address, mostly used for stack regions.
        """

        if self.is_stack:
            # Sanity check
            if not region_id.startswith('stack_'):
                raise SimRegionMapError(
                    'Received a non-stack memory ID "%s" in a stack region map'
                    % region_id)

            # Remove all stack regions that are lower than the one to add
            while True:
                try:
                    addr = next(
                        self._address_to_region_id.irange(
                            maximum=absolute_address, reverse=True))
                    descriptor = self._address_to_region_id[addr]
                    # Remove this mapping
                    del self._address_to_region_id[addr]
                    # Remove this region ID from the other mapping
                    del self._region_id_to_address[descriptor.region_id]
                except StopIteration:
                    break

        else:
            if absolute_address in self._address_to_region_id:
                descriptor = self._address_to_region_id[absolute_address]
                # Remove this mapping
                del self._address_to_region_id[absolute_address]
                del self._region_id_to_address[descriptor.region_id]

        # Add this new region mapping
        desc = RegionDescriptor(
            region_id,
            absolute_address,
            related_function_address=related_function_address)

        self._address_to_region_id[absolute_address] = desc
        self._region_id_to_address[region_id] = desc

    def unmap_by_address(self, absolute_address):
        """
        Removes a mapping based on its absolute address.

        :param absolute_address: An absolute address
        """

        desc = self._address_to_region_id[absolute_address]
        del self._address_to_region_id[absolute_address]
        del self._region_id_to_address[desc.region_id]

    def absolutize(self, region_id, relative_address):
        """
        Convert a relative address in some memory region to an absolute address.

        :param region_id:           The memory region ID
        :param relative_address:    The relative memory offset in that memory region
        :return:                    An absolute address if converted, or an exception is raised when region id does not
                                    exist.
        """

        if region_id == 'global':
            # The global region always bases 0
            return relative_address

        if region_id not in self._region_id_to_address:
            raise SimRegionMapError('Non-existent region ID "%s"' % region_id)

        base_address = self._region_id_to_address[region_id].base_address
        return base_address + relative_address

    def relativize(self, absolute_address, target_region_id=None):
        """
        Convert an absolute address to the memory offset in a memory region.

        Note that if an address belongs to heap region is passed in to a stack region map, it will be converted to an
        offset included in the closest stack frame, and vice versa for passing a stack address to a heap region.
        Therefore you should only pass in address that belongs to the same category (stack or non-stack) of this region
        map.

        :param absolute_address:    An absolute memory address
        :return:                    A tuple of the closest region ID, the relative offset, and the related function
                                    address.
        """

        if target_region_id is None:
            if self.is_stack:
                # Get the base address of the stack frame it belongs to
                base_address = next(
                    self._address_to_region_id.irange(minimum=absolute_address,
                                                      reverse=False))

            else:
                try:
                    base_address = next(
                        self._address_to_region_id.irange(
                            maximum=absolute_address, reverse=True))

                except StopIteration:
                    # Not found. It belongs to the global region then.
                    return 'global', absolute_address, None

            descriptor = self._address_to_region_id[base_address]

        else:
            if target_region_id == 'global':
                # Just return the absolute address
                return 'global', absolute_address, None

            if target_region_id not in self._region_id_to_address:
                raise SimRegionMapError(
                    'Trying to relativize to a non-existent region "%s"' %
                    target_region_id)

            descriptor = self._region_id_to_address[target_region_id]
            base_address = descriptor.base_address

        return descriptor.region_id, absolute_address - base_address, descriptor.related_function_address
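
A hypothetical round trip (the region ID and addresses are made up; this assumes angr's RegionDescriptor exposes the stored address as .base_address):

rm = RegionMap(is_stack=False)
rm.map(0x601000, 'global_data')
abs_addr = rm.absolutize('global_data', 0x20)  # 0x601020
region, offset, _ = rm.relativize(0x601020)    # ('global_data', 0x20, None)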
Example #11
def check_split_by_blast(transcript, cds_boundaries):
    """
    This method verifies whether a transcript with multiple ORFs has BLAST support for
    NOT splitting it into its different components.

    The minimal overlap between ORF and HSP is defined inside the JSON at the key
        ["chimera_split"]["blast_params"]["minimal_hsp_overlap"];
    basically, we consider an HSP a hit only if its overlap with the ORF is over that
    threshold and the HSP e-value is under a certain threshold.

    The split by CDS can be executed in three different ways - PERMISSIVE, LENIENT, STRINGENT:

    - PERMISSIVE: split if two CDSs do not have hits in common,
    even when one or both do not have a hit at all.
    - STRINGENT: split only if two CDSs have hits and none
    of those is in common between them.
    - LENIENT: split if *both* lack hits, OR *both* have hits and none
    of those is in common.

    :param transcript: the transcript instance
    :type transcript: Mikado.loci_objects.transcript.Transcript
    :param cds_boundaries:
    :return: cds_boundaries
    :rtype: dict
    """

    # Establish the minimum overlap between an ORF and a BLAST hit to consider it
    # to establish belongingness

    minimal_overlap = transcript.json_conf["pick"]["chimera_split"][
        "blast_params"]["minimal_hsp_overlap"]

    cds_hit_dict = SortedDict.fromkeys(cds_boundaries.keys())
    for key in cds_hit_dict:
        cds_hit_dict[key] = collections.defaultdict(list)

    # BUG, this is a hacky fix
    if not hasattr(transcript, "blast_hits"):
        transcript.logger.warning(
            "BLAST hits store lost for %s! Creating a mock one to avoid a crash",
            transcript.id)
        transcript.blast_hits = []

    transcript.logger.debug("%s has %d possible hits", transcript.id,
                            len(transcript.blast_hits))

    # Determine for each CDS which are the hits available
    min_eval = transcript.json_conf["pick"]['chimera_split']['blast_params'][
        'hsp_evalue']
    for hit in transcript.blast_hits:
        for hsp in iter(_hsp for _hsp in hit["hsps"]
                        if _hsp["hsp_evalue"] <= min_eval):
            for cds_run in cds_boundaries:
                # If I have a valid hit b/w the CDS region and the hit,
                # add the name to the set
                overlap_threshold = minimal_overlap * (cds_run[1] + 1 -
                                                       cds_run[0])
                overl = overlap(cds_run,
                                (hsp['query_hsp_start'], hsp['query_hsp_end']))

                if overl >= overlap_threshold:
                    cds_hit_dict[cds_run][(hit["target"],
                                           hit["target_length"])].append(hsp)
                    transcript.logger.debug(
                        "Overlap %s passed for %s between %s CDS and %s HSP (threshold %s)",
                        overl, transcript.id, cds_run,
                        (hsp['query_hsp_start'], hsp['query_hsp_end']),
                        overlap_threshold)
                else:
                    transcript.logger.debug(
                        "Overlap %s rejected for %s between %s CDS and %s HSP (threshold %s)",
                        overl, transcript.id, cds_run,
                        (hsp['query_hsp_start'], hsp['query_hsp_end']),
                        overlap_threshold)

    transcript.logger.debug("Final cds_hit_dict for %s: %s", transcript.id,
                            cds_hit_dict)

    final_boundaries = SortedDict()
    for boundary in __get_boundaries_from_blast(transcript, cds_boundaries,
                                                cds_hit_dict):
        if len(boundary) == 1:
            assert len(boundary[0]) == 2
            boundary = boundary[0]
            final_boundaries[boundary] = cds_boundaries[boundary]
        else:
            nboun = (boundary[0][0], boundary[-1][1])
            final_boundaries[nboun] = []
            for boun in boundary:
                final_boundaries[nboun].extend(cds_boundaries[boun])
    transcript.logger.debug("Final boundaries for %s: %s", transcript.id,
                            final_boundaries)

    cds_boundaries = final_boundaries.copy()
    return cds_boundaries
Example #12
class RegionMap(object):
    """
    Mostly used in SimAbstractMemory, RegionMap stores a series of mappings between concrete memory address ranges and
    memory regions, like stack frames and heap regions.
    """

    def __init__(self, is_stack):
        """
        Constructor

        :param is_stack:    Whether this is a region map for stack frames or not. Different strategies apply for stack
                            regions.
        """
        self.is_stack = is_stack

        # A sorted dict, which maps absolute addresses to region descriptors
        self._address_to_region_id = SortedDict()
        # A dict, which maps region IDs to region descriptors
        self._region_id_to_address = { }

    #
    # Properties
    #

    def __repr__(self):
        return "RegionMap<%s>" % (
            "S" if self.is_stack else "H"
        )

    @property
    def is_empty(self):
        return len(self._address_to_region_id) == 0

    @property
    def stack_base(self):
        if not self.is_stack:
            raise SimRegionMapError('Calling "stack_base" on a non-stack region map.')

        return next(self._address_to_region_id.irange(reverse=True))

    @property
    def region_ids(self):
        return self._region_id_to_address.keys()

    #
    # Public methods
    #

    @SimStatePlugin.memo
    def copy(self, memo): # pylint: disable=unused-argument
        r = RegionMap(is_stack=self.is_stack)

        # A shallow copy should be enough, since we never modify any RegionDescriptor object in-place
        r._address_to_region_id = self._address_to_region_id.copy()
        r._region_id_to_address = self._region_id_to_address.copy()

        return r

    def map(self, absolute_address, region_id, related_function_address=None):
        """
        Add a mapping between an absolute address and a region ID. If this is a stack region map, all stack regions
        beyond (lower than) this newly added regions will be discarded.

        :param absolute_address:            An absolute memory address.
        :param region_id:                   ID of the memory region.
        :param related_function_address:    A related function address, mostly used for stack regions.
        """

        if self.is_stack:
            # Sanity check
            if not region_id.startswith('stack_'):
                raise SimRegionMapError('Received a non-stack memory ID "%s" in a stack region map' % region_id)

            # Remove all stack regions that are lower than the one to add
            while True:
                try:
                    addr = next(self._address_to_region_id.irange(maximum=absolute_address, reverse=True))
                    descriptor = self._address_to_region_id[addr]
                    # Remove this mapping
                    del self._address_to_region_id[addr]
                    # Remove this region ID from the other mapping
                    del self._region_id_to_address[descriptor.region_id]
                except StopIteration:
                    break

        else:
            if absolute_address in self._address_to_region_id:
                descriptor = self._address_to_region_id[absolute_address]
                # Remove this mapping
                del self._address_to_region_id[absolute_address]
                del self._region_id_to_address[descriptor.region_id]

        # Add this new region mapping
        desc = RegionDescriptor(
            region_id,
            absolute_address,
            related_function_address=related_function_address
        )

        self._address_to_region_id[absolute_address] = desc
        self._region_id_to_address[region_id] = desc

    def unmap_by_address(self, absolute_address):
        """
        Removes a mapping based on its absolute address.

        :param absolute_address: An absolute address
        """

        desc = self._address_to_region_id[absolute_address]
        del self._address_to_region_id[absolute_address]
        del self._region_id_to_address[desc.region_id]

    def absolutize(self, region_id, relative_address):
        """
        Convert a relative address in some memory region to an absolute address.

        :param region_id:           The memory region ID
        :param relative_address:    The relative memory offset in that memory region
        :return:                    An absolute address if converted, or an exception is raised when region id does not
                                    exist.
        """

        if region_id == 'global':
            # The global region always bases 0
            return relative_address

        if region_id not in self._region_id_to_address:
            raise SimRegionMapError('Non-existent region ID "%s"' % region_id)

        base_address = self._region_id_to_address[region_id].base_address
        return base_address + relative_address

    def relativize(self, absolute_address, target_region_id=None):
        """
        Convert an absolute address to the memory offset in a memory region.

        Note that if an address belongs to heap region is passed in to a stack region map, it will be converted to an
        offset included in the closest stack frame, and vice versa for passing a stack address to a heap region.
        Therefore you should only pass in address that belongs to the same category (stack or non-stack) of this region
        map.

        :param absolute_address:    An absolute memory address
        :return:                    A tuple of the closest region ID, the relative offset, and the related function
                                    address.
        """

        if target_region_id is None:
            if self.is_stack:
                # Get the base address of the stack frame it belongs to
                base_address = next(self._address_to_region_id.irange(minimum=absolute_address, reverse=False))

            else:
                try:
                    base_address = next(self._address_to_region_id.irange(maximum=absolute_address, reverse=True))

                except StopIteration:
                    # Not found. It belongs to the global region then.
                    return 'global', absolute_address, None

            descriptor = self._address_to_region_id[base_address]

        else:
            if target_region_id == 'global':
                # Just return the absolute address
                return 'global', absolute_address, None

            if target_region_id not in self._region_id_to_address:
                raise SimRegionMapError('Trying to relativize to a non-existent region "%s"' % target_region_id)

            descriptor = self._region_id_to_address[target_region_id]
            base_address = descriptor.base_address

        return descriptor.region_id, absolute_address - base_address, descriptor.related_function_address
Example #13
class PrecisionRecallCurve:
    """
	Represents a curve relating a chosen detection threshold to precision and recall.  Internally, this is actually
	stored as a sorted list of detection events, which are used to compute metrics on the fly when needed.
	"""

    # TODO(mdsavage): make this accept matching strategies other than bounding box IOU

    events: SortedDict[float, _DetectionEvent]
    ground_truth_positives: int

    def __init__(self,
                 events: Optional[SortedDict[float, _DetectionEvent]] = None,
                 ground_truth_positives: int = 0):
        self.events = SortedDict() if events is None else events
        self.ground_truth_positives = ground_truth_positives

    def clone(self) -> PrecisionRecallCurve:
        return PrecisionRecallCurve(self.events.copy(),
                                    self.ground_truth_positives)

    def maximize_f1(self) -> MaximizeF1Result:
        maximum = MaximizeF1Result(threshold=1, precision=0, recall=0, f1=0)

        for threshold, precision, recall in self._compute_curve():
            f1 = 2 / ((1 / precision) +
                      (1 / recall)) if precision > 0 and recall > 0 else 0
            if f1 >= maximum.f1:
                maximum = MaximizeF1Result(threshold=threshold,
                                           precision=precision,
                                           recall=recall,
                                           f1=f1)

        return maximum

    def plot(self) -> plt.Figure:
        import matplotlib.pyplot as plt
        fig = plt.figure()
        curve = self._compute_curve()
        plt.plot([pt.recall for pt in curve], [pt.precision for pt in curve],
                 "o-")
        plt.xlabel("Recall")
        plt.ylabel("Precision")
        return fig

    def add_annotation(self: PrecisionRecallCurve,
                       ground_truth: ImageAnnotation,
                       prediction: ImageAnnotation,
                       iou_threshold: float) -> None:
        """
		Returns a precision-recall curve for the given ground truth and prediction annotations evaluated with the given
		IOU threshold.

		Note: this handles instances only; multi-instances are ignored.
		"""
        ground_truth_boxes = [
            GroundTruthBox(class_name, instance.bounding_box.rectangle)
            for class_name in ground_truth.classes.keys()
            for instance in ground_truth.classes[class_name].instances
            if instance.bounding_box is not None
        ]

        prediction_boxes = sorted([
            PredictionBox(instance.bounding_box.confidence or 1, class_name,
                          instance.bounding_box.rectangle)
            for class_name in prediction.classes.keys()
            for instance in prediction.classes[class_name].instances
            if instance.bounding_box is not None
        ],
                                  reverse=True,
                                  key=lambda p: p.confidence)

        iou_matrix = np.array([[
            ground_truth_box.box.iou(prediction_box.box)
            for ground_truth_box in ground_truth_boxes
        ] for prediction_box in prediction_boxes])

        self._add_ground_truth_positives(len(ground_truth_boxes))

        previous_true_positives = 0
        previous_false_positives = 0

        for i in range(len(prediction_boxes)):
            confidence_threshold = prediction_boxes[i].confidence

            if i < len(prediction_boxes) - 1 and prediction_boxes[
                    i + 1].confidence == confidence_threshold:
                continue

            prediction_indices, ground_truth_indices = linear_sum_assignment(
                iou_matrix[:i + 1, ], maximize=True)

            true_positives = 0
            false_positives = max(0, i + 1 - len(ground_truth_boxes))

            for prediction_index, ground_truth_index in zip(
                    cast(Iterable[int], prediction_indices),
                    cast(Iterable[int], ground_truth_indices)):
                if (iou_matrix[prediction_index,
                               ground_truth_index] >= iou_threshold
                        and prediction_boxes[prediction_index].class_name
                        == ground_truth_boxes[ground_truth_index].class_name):
                    true_positives += 1
                else:
                    false_positives += 1

            self._add_event(
                confidence_threshold,
                _DetectionEvent(true_positive_delta=true_positives -
                                previous_true_positives,
                                false_positive_delta=false_positives -
                                previous_false_positives))

            previous_true_positives = true_positives
            previous_false_positives = false_positives

    def batch_add_annotation(self: PrecisionRecallCurve,
                             ground_truths: Sequence[ImageAnnotation],
                             predictions: Sequence[ImageAnnotation],
                             iou_threshold: float) -> None:
        """
		Updates this precision-recall curve with the values from several annotations simultaneously.
		"""
        for ground_truth, prediction in zip(ground_truths, predictions):
            self.add_annotation(ground_truth, prediction, iou_threshold)

    def _compute_curve(self) -> List[_PrecisionRecallPoint]:
        assert self.ground_truth_positives > 0
        precision_recall_points: List[_PrecisionRecallPoint] = []

        true_positives = 0
        detections = 0

        for threshold in reversed(self.events):
            true_positive_delta, false_positive_delta = self.events[threshold]
            true_positives += true_positive_delta
            detections += true_positive_delta + false_positive_delta
            assert detections > 0

            precision_recall_points.append(
                _PrecisionRecallPoint(threshold=threshold,
                                      precision=true_positives / detections,
                                      recall=true_positives /
                                      self.ground_truth_positives))

        return precision_recall_points

    def _add_event(self, threshold: float, event: _DetectionEvent) -> None:
        if threshold not in self.events:
            self.events[threshold] = _DetectionEvent(0, 0)
        self.events[threshold] += event

    def _add_ground_truth_positives(self, count: int) -> None:
        self.ground_truth_positives += count

    def __add__(self, other: PrecisionRecallCurve) -> PrecisionRecallCurve:
        if isinstance(
                other, PrecisionRecallCurve
        ):  # type: ignore - pyright complains about the isinstance check being redundant
            ret = self.clone()
            ret._add_ground_truth_positives(other.ground_truth_positives)

            for threshold, event in other.events.items():
                ret._add_event(threshold, event)

            return ret
        return NotImplemented
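
The guarded expression in maximize_f1 is just the harmonic mean F1 = 2PR / (P + R); a worked instance (numbers are made up):

precision, recall = 0.8, 0.6
f1 = 2 / ((1 / precision) + (1 / recall))  # == 2 * 0.8 * 0.6 / (0.8 + 0.6) ~= 0.686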