from intervaltree import Interval, IntervalTree

def test_interval_removal_72():
    tree = IntervalTree([
        Interval(0.0, 2.588, 841),
        Interval(65.5, 85.8, 844),
        Interval(93.6, 130.0, 837),
        Interval(125.0, 196.5, 829),
        Interval(391.8, 521.0, 825),
        Interval(720.0, 726.0, 834),
        Interval(800.0, 1033.0, 850),
        Interval(800.0, 1033.0, 855),
    ])
    tree.verify()
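    # remove_overlap discards the five intervals overlapping [0.0, 521.0);
    # the three intervals at 720.0 and beyond are untouched.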
    tree.remove_overlap(0.0, 521.0)
    tree.verify()
Example 2
def select_top_overlapping_intervals(begin,end,closed_end=False):
    """Given a collection of overlapping intervals, select one "best" for each overlapping subset.
    For example, intervals can be BLASTN hits from reference DB to the contig.

    Intervals should be provided as begin and end iterables, **sorted in decreasing** order
    by "interval strength" that will define the preference when selecting
    among overlapping intervals.

    Result: a generator that yields indices of the selected records in the input.

    How selection is currently done: if A, B and C are intervals listed in
    decreasing order of preference, B overlaps both A and C, and A does not
    overlap C, then both A and C will be returned as selections. Moving in
    the initial order of intervals, the algorithm discards from future
    consideration all intervals that overlap the currently considered
    interval, stores the index of the current interval, and moves to the
    next remaining interval.
    """
    ##If other variations of pruning the overlapping intervals are needed,
    ##we can use any of the Python graph libraries such as APGL, build a
    ##graph from overlaps and use methods like findConnectedComponents() etc.
    from intervaltree import Interval, IntervalTree
    intervals = [Interval(*iv) for iv in
                 zip(begin, end if not closed_end else (e + 1 for e in end))]
    tree = IntervalTree(intervals)
    ##TODO: what is the cost of removing tree nodes? Maybe it is cheaper
    ##to tag nodes overlapping the current one in a separate array.
    for iiv,iv in enumerate(intervals):
        ## docs say membership check for an Interval object is O(1)
        if iv in tree:
            yield iiv
            try:
                tree.remove_overlap(iv.begin,iv.end)
            except KeyError as msg:
                log.warning("KeyError when removing existing node from IntervalTree: {}. "
                            "This must be a bug to fix in IntervalTree code. "
                            "Ignoring this error for now.".format(msg))
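
A minimal usage sketch, assuming the inputs are already sorted by decreasing
interval strength:

begin = [0, 2, 8]
end = [5, 6, 12]  # interval 1 overlaps interval 0; interval 2 is disjoint
print(list(select_top_overlapping_intervals(begin, end)))  # -> [0, 2]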
Example 3
class MemoryCache(object):
    def __init__(self, context):
        self._context = context
        self._run_token = -1
        self._log = logging.getLogger('memcache')
        self._reset_cache()

    def _reset_cache(self):
        self._cache = IntervalTree()
        self._metrics = CacheMetrics()

    ##
    # @brief Invalidates the cache if appropriate.
    def _check_cache(self):
        if self._context.core.is_running():
            self._log.debug("core is running; invalidating cache")
            self._reset_cache()
        elif self._run_token != self._context.core.run_token:
            self._dump_metrics()
            self._log.debug("out of date run token; invalidating cache")
            self._reset_cache()
            self._run_token = self._context.core.run_token

    ##
    # @brief Splits a memory address range into cached and uncached subranges.
    # @return Returns a 2-tuple with the first element being a set of Interval objects for each
    #   of the cached subranges. The second element is a set of Interval objects for each of the
    #   non-cached subranges.
    def _get_ranges(self, addr, count):
        cached = self._cache.overlap(addr, addr + count)
        uncached = {Interval(addr, addr + count)}
        for cachedIv in cached:
            newUncachedSet = set()
            for uncachedIv in uncached:

                # No overlap.
                if cachedIv.end < uncachedIv.begin or cachedIv.begin > uncachedIv.end:
                    newUncachedSet.add(uncachedIv)
                    continue

                # Begin segment.
                if cachedIv.begin - uncachedIv.begin > 0:
                    newUncachedSet.add(
                        Interval(uncachedIv.begin, cachedIv.begin))

                # End segment.
                if uncachedIv.end - cachedIv.end > 0:
                    newUncachedSet.add(Interval(cachedIv.end, uncachedIv.end))
            uncached = newUncachedSet
        return cached, uncached
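
    # For intuition: the loop above subtracts each cached interval from the
    # requested range. The same gap computation can be sketched with
    # IntervalTree.chop(), which cuts a range out of a tree (a standalone
    # sketch, not part of this class; 'cache' stands in for self._cache):
    #
    #   gaps = IntervalTree([Interval(addr, addr + count)])
    #   for iv in cache.overlap(addr, addr + count):
    #       gaps.chop(iv.begin, iv.end)
    #   sorted(gaps)  # -> the uncached subranges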

    ##
    # @brief Reads uncached memory ranges and updates the cache.
    # @return A list of Interval objects is returned. Each Interval has its @a data attribute set
    #   to a bytearray of the data read from target memory.
    def _read_uncached(self, uncached):
        uncachedData = []
        for uncachedIv in uncached:
            data = self._context.read_memory_block8(
                uncachedIv.begin, uncachedIv.end - uncachedIv.begin)
            iv = Interval(uncachedIv.begin, uncachedIv.end, bytearray(data))
            self._cache.add(iv)  # TODO merge contiguous cached intervals
            uncachedData.append(iv)
        return uncachedData

    def _update_metrics(self, cached, uncached, addr, size):
        cachedSize = 0
        for iv in cached:
            begin = iv.begin
            end = iv.end
            if iv.begin < addr:
                begin = addr
            if iv.end > addr + size:
                end = addr + size
            cachedSize += end - begin

        uncachedSize = sum((iv.end - iv.begin) for iv in uncached)

        self._metrics.reads += 1
        self._metrics.hits += cachedSize
        self._metrics.misses += uncachedSize

    def _dump_metrics(self):
        if self._metrics.total > 0:
            self._log.debug(
                "%d reads, %d bytes [%d%% hits, %d bytes]; %d bytes written",
                self._metrics.reads, self._metrics.total,
                self._metrics.percent_hit, self._metrics.hits,
                self._metrics.writes)
        else:
            self._log.debug("no reads")

    ##
    # @brief Performs a cached read operation of an address range.
    # @return A list of Interval objects sorted by address.
    def _read(self, addr, size):
        # Get the cached and uncached subranges of the requested read.
        cached, uncached = self._get_ranges(addr, size)
        self._update_metrics(cached, uncached, addr, size)

        # Read any uncached ranges.
        uncachedData = self._read_uncached(uncached)

        # Merge cached data with the data we just read.
        combined = list(cached) + uncachedData
        combined.sort(key=lambda x: x.begin)
        return combined

    ##
    # @brief Extracts data from the intersection of an address range across a list of interval objects.
    #
    # The range represented by @a addr and @a size are assumed to overlap the intervals. The first
    # and last interval in the list may have ragged edges not fully contained in the address range, in
    # which case the correct slice of those intervals is extracted.
    #
    # @param self
    # @param combined List of Interval objects forming a contiguous range. The @a data attribute of
    #   each interval must be a bytearray.
    # @param addr Start address. Must be within the range of the first interval.
    # @param size Number of bytes. (@a addr + @a size) must be within the range of the last interval.
    # @return A single bytearray object with all data from the intervals that intersects the address
    #   range.
    def _merge_data(self, combined, addr, size):
        result = bytearray()
        resultAppend = bytearray()

        # Check for fully contained subrange.
        if len(combined) and combined[0].begin < addr and combined[0].end > addr + size:
            offset = addr - combined[0].begin
            endOffset = offset + size
            result = combined[0].data[offset:endOffset]
            return result

        # Take slice of leading ragged edge.
        if len(combined) and combined[0].begin < addr:
            offset = addr - combined[0].begin
            result += combined[0].data[offset:]
            combined = combined[1:]
        # Take slice of trailing ragged edge.
        if len(combined) and combined[-1].end > addr + size:
            offset = addr + size - combined[-1].begin
            resultAppend = combined[-1].data[:offset]
            combined = combined[:-1]

        # Merge.
        for iv in combined:
            result += iv.data
        result += resultAppend

        return result
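
    # Worked example: with combined == [Interval(0, 4, b'aaaa'),
    # Interval(4, 8, b'bbbb'), Interval(8, 12, b'cccc')], addr == 2 and
    # size == 8: the leading ragged edge contributes data[2:] == b'aa', the
    # trailing ragged edge contributes data[:2] == b'cc', and the middle
    # interval is copied whole, giving b'aabbbbcc' (exactly 8 bytes).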

    ##
    # @brief Merges written data with overlapping cached intervals into one contiguous interval.
    def _update_contiguous(self, cached, addr, value):
        size = len(value)
        end = addr + size
        leadBegin = addr
        leadData = bytearray()
        trailData = bytearray()
        trailEnd = end

        if cached[0].begin < addr and cached[0].end > addr:
            offset = addr - cached[0].begin
            leadData = cached[0].data[:offset]
            leadBegin = cached[0].begin
        if cached[-1].begin < end and cached[-1].end > end:
            offset = end - cached[-1].begin
            trailData = cached[-1].data[offset:]
            trailEnd = cached[-1].end

        self._cache.remove_overlap(addr, end)

        data = leadData + value + trailData
        self._cache.addi(leadBegin, trailEnd, data)

    ##
    # @return A bool indicating whether the given address range is fully contained within
    #       one known memory region, and that region is cacheable.
    # @exception MemoryAccessError Raised if the access is not entirely contained within a single region.
    def _check_regions(self, addr, count):
        regions = self._context.core.memory_map.get_intersecting_regions(
            addr, length=count)

        # If no regions matched, then allow an uncached operation.
        if len(regions) == 0:
            return False

        # Raise if not fully contained within one region.
        if len(regions) > 1 or not regions[0].contains_range(addr, length=count):
            raise MemoryAccessError(
                "individual memory accesses must not cross memory region boundaries"
            )

        # Otherwise return whether the region is cacheable.
        return regions[0].is_cacheable

    def read_memory(self, addr, transfer_size=32, now=True):
        # TODO use more optimal underlying read_memory call
        if transfer_size == 8:
            data = self.read_memory_block8(addr, 1)[0]
        elif transfer_size == 16:
            data = conversion.byte_list_to_u16le_list(
                self.read_memory_block8(addr, 2))[0]
        elif transfer_size == 32:
            data = conversion.byte_list_to_u32le_list(
                self.read_memory_block8(addr, 4))[0]

        if now:
            return data
        else:

            def read_cb():
                return data

            return read_cb

    def read_memory_block8(self, addr, size):
        if size <= 0:
            return []

        self._check_cache()

        # Validate memory regions.
        if not self._check_regions(addr, size):
            self._log.debug("range [%x:%x] is not cacheable", addr,
                            addr + size)
            return self._context.read_memory_block8(addr, size)

        # Get the cached and uncached subranges of the requested read.
        combined = self._read(addr, size)

        # Extract data out of combined intervals.
        result = list(self._merge_data(combined, addr, size))
        assert len(result) == size, \
            "result size ({}) != requested size ({})".format(len(result), size)
        return result

    def read_memory_block32(self, addr, size):
        return conversion.byte_list_to_u32le_list(
            self.read_memory_block8(addr, size * 4))

    def write_memory(self, addr, value, transfer_size=32):
        if transfer_size == 8:
            return self.write_memory_block8(addr, [value])
        elif transfer_size == 16:
            return self.write_memory_block8(
                addr, conversion.u16le_list_to_byte_list([value]))
        elif transfer_size == 32:
            return self.write_memory_block8(
                addr, conversion.u32le_list_to_byte_list([value]))

    def write_memory_block8(self, addr, value):
        if len(value) <= 0:
            return

        self._check_cache()

        # Validate memory regions.
        cacheable = self._check_regions(addr, len(value))

        # Write to the target first, so if it fails we don't update the cache.
        result = self._context.write_memory_block8(addr, value)

        if cacheable:
            size = len(value)
            end = addr + size
            cached = sorted(self._cache.overlap(addr, end),
                            key=lambda x: x.begin)
            self._metrics.writes += size

            if len(cached):
                # Write data is entirely within a single cached interval.
                if addr >= cached[0].begin and end <= cached[0].end:
                    beginOffset = addr - cached[0].begin
                    endOffset = beginOffset + size
                    cached[0].data[beginOffset:endOffset] = value

                else:
                    self._update_contiguous(cached, addr, bytearray(value))
            else:
                # No cached data in this range, so just add the entire interval.
                self._cache.addi(addr, end, bytearray(value))

        return result

    def write_memory_block32(self, addr, data):
        return self.write_memory_block8(
            addr, conversion.u32le_list_to_byte_list(data))

    def invalidate(self):
        self._reset_cache()
Example 4
class MainClient:

    def __init__(self, zkquorum, pool_size):
        # Location of the ZooKeeper quorum (csv)
        self.zkquorum = zkquorum
        # Connection pool size per region server (and master!)
        self.pool_size = pool_size
        # Persistent connection to the master server.
        self.master_client = None
        # IntervalTree data structure that allows me to create ranges
        # representing known row keys that fall within a specific region. Any
        # 'region look up' is then O(logn)
        self.region_cache = IntervalTree()
        # Takes a client's host:port as key and maps it to a client instance.
        self.reverse_client_cache = {}
        # Mutex used for all caching operations.
        self._cache_lock = Lock()
        # Mutex used so only one thread can request meta information from
        # the master at a time.
        self._master_lookup_lock = Lock()

    """
        HERE LAY CACHE OPERATIONS
    """

    def _add_to_region_cache(self, new_region):
        stop_key = new_region.stop_key
        if stop_key == '':
            # This is hacky but our interval tree requires hard interval stops.
            # So what's the largest char out there? chr(255) -> '\xff'. If
            # you're using '\xff' as a prefix for your rows then this'll cause
            # a cache miss on every request.
            stop_key = '\xff'
        # Keys are formatted like: 'tablename,key'
        start_key = new_region.table + ',' + new_region.start_key
        stop_key = new_region.table + ',' + stop_key

        # Only let one person touch the cache at once.
        with self._cache_lock:
            # Get all overlapping regions (overlapping == stale)
            overlapping_regions = self.region_cache[start_key:stop_key]
            # Close the overlapping regions.
            self._close_old_regions(overlapping_regions)
            # Remove the overlapping regions.
            self.region_cache.remove_overlap(start_key, stop_key)
            # Insert my region.
            self.region_cache[start_key:stop_key] = new_region
            # Add this region to the region_client's internal
            # list of all the regions it serves.
            new_region.region_client.regions.append(new_region)

    def _get_from_region_cache(self, table, key):
        # Only let one person touch the cache at once.
        with self._cache_lock:
            # We don't care about the last two characters ',:' in the meta_key.
            # 'table,key,:' --> 'table,key'
            meta_key = self._construct_meta_key(table, key)[:-2]
            # Fetch the region that serves this key
            regions = self.region_cache[meta_key]
            try:
                # Returns a set. Pop the element from the set.
                # (there shouldn't be more than 1 elem in the set)
                a = regions.pop()
                return a.data
            except KeyError:
                # Returned set is empty? Cache miss!
                return None
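
    # For intuition: the cache is an IntervalTree keyed by 'table,row'
    # strings, so a point query returns the region whose key range covers
    # the row. A toy sketch with hypothetical region names:
    #
    #   cache = IntervalTree()
    #   cache['t1,a':'t1,m'] = 'region-1'     # serves rows ['a', 'm')
    #   cache['t1,m':'t1,\xff'] = 'region-2'
    #   cache['t1,cat'].pop().data            # -> 'region-1'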

    def _delete_from_region_cache(self, table, start_key):
        # Don't acquire the lock because the calling function should have done
        # so already
        self.region_cache.remove_overlap(table + "," + start_key)

    """
        HERE LAY REQUESTS
    """

    def get(self, table, key, families={}, filters=None):
        """
        get a row or specified cell with optional filter
        :param table: hbase table
        :param key: row key
        :param families: (optional) specifies columns to get,
          e.g., {"columnFamily1":["col1","col2"], "colFamily2": "col3"}
        :param filters: (optional) column filters
        :return: response with cells
        """
        try:
            # Step 0. Set dest_region to None so if an exception is
            # thrown in _find_hosting_region, the exception handling
            # doesn't break trying to reference dest_region.
            dest_region = None
            # Step 1. Figure out where to send it.
            dest_region = self._find_hosting_region(table, key)
            # Step 2. Build the appropriate pb message.
            rq = request.get_request(dest_region, key, families, filters)
            # Step 3. Send the message and twiddle our thumbs.
            response = dest_region.region_client._send_request(rq)
            # Step 4. Success.
            return Result(response)
        except PyBaseException as e:
            # Step X. Houston, we have an error. The cool thing about how
            # this is coded is that exceptions know how to handle themselves.
            # All we need to do is call _handle_exception and everything should
            # be happy! If it cannot handle itself (unrecoverable) then it will
            # re-raise the exception in the handle method and we'll die too.
            #
            # We pass dest_region in because the handling code needs to know
            # which region or region_client it needs to reestablish.
            e._handle_exception(self, dest_region=dest_region)
            # Everything should be dandy now. Repeat the request!
            return self.get(table, key, families=families, filters=filters)

    def put(self, table, key, values):
        return self._mutate(table, key, values, request.put_request)

    def delete(self, table, key, values):
        return self._mutate(table, key, values, request.delete_request)

    def append(self, table, key, values):
        return self._mutate(table, key, values, request.append_request)

    def increment(self, table, key, values):
        return self._mutate(table, key, values, request.increment_request)

    def _mutate(self, table, key, values, rq_type):
        # Same exact methodology as 'get'. Because all mutate requests have
        # equivalent code I've combined them into a single function.
        try:
            dest_region = None
            dest_region = self._find_hosting_region(table, key)
            rq = rq_type(dest_region, key, values)
            response = dest_region.region_client._send_request(rq)
            return Result(response)
        except PyBaseException as e:
            e._handle_exception(self, dest_region=dest_region)
            return self._mutate(table, key, values, rq_type)

    # Scan can get a bit gnarly - be prepared.
    def scan(self, table, start_key='', stop_key=None, families={}, filters=None):
        # We convert the filter immediately such that it doesn't have to be done
        # for every region. However if the filter has already been converted then
        # we can't convert it again. This means that even though we send out N RPCs
        # we only have to package the filter pb type once.
        if filters is not None and type(filters).__name__ != "Filter":
            filters = _to_filter(filters)
        previous_stop_key = start_key
        # Holds the contents of all responses. We return this at the end.
        result_set = Result(None)
        # We're going to need to loop over every relevant region. Break out
        # of this loop once we discover there are no more regions left to scan.
        while True:
            # Finds the first region and sends the initial message to it.
            first_response, cur_region = self._scan_hit_region_once(
                previous_stop_key, table, start_key, stop_key, families, filters)
            try:
                # Now we need to keep pinging this region for more results until
                # it has no more results to return. We can change how many rows it
                # returns for each call in the Requests module but I picked a
                # pseudo-arbitrary figure (alright, fine, I stole it from
                # asynchbase)
                #
                # We pass in first_response so it can pull out the scanner_id
                # from the first response.
                second_response = self._scan_region_while_more_results(
                    cur_region, first_response)
            except PyBaseException as e:
                # Something happened to the region/region client in the middle
                # of a scan. We're going to handle it by...
                #
                # Handle the exception.
                e._handle_exception(self, dest_region=cur_region)
                # Recursively scan JUST this range of keys in the region (it could have been split
                # or merged so this recursive call may be scanning multiple regions or only half
                # of one region).
                result_set._append_response(self.scan(
                    table, start_key=previous_stop_key, stop_key=cur_region.stop_key, families=families, filters=filters))
                # We continue here because we don't want to append the
                # first_response results to the result_set. When we did the
                # recursive scan it rescanned whatever the first_response
                # initially contained. Appending both will produce duplicates.
                previous_stop_key = cur_region.stop_key
                if previous_stop_key == '' or (stop_key is not None and previous_stop_key > stop_key):
                    break
                continue
            # Both calls succeeded! Append the results to the result_set.
            result_set._append_response(first_response)
            result_set._append_response(second_response)
            # Update the new previous_stop_key (so the next iteration can
            # lookup the next region to scan)
            previous_stop_key = cur_region.stop_key
            # Stopping criteria. This region is either the end ('') or the end of this region is
            # beyond the specific stop_key.
            if previous_stop_key == '' or (stop_key is not None and previous_stop_key > stop_key):
                break
        return result_set

    def _scan_hit_region_once(self, previous_stop_key, table, start_key, stop_key, families, filters):
        try:
            # Lookup the next region to scan by searching for the
            # previous_stop_key (region keys are inclusive on the start and
            # exclusive on the end)
            cur_region = self._find_hosting_region(
                table, previous_stop_key)
        except PyBaseException as e:
            # This means that either Master is down or something's funky with the META region. Try handling it
            # and recursively perform the same call again.
            e._handle_exception(self)
            return self._scan_hit_region_once(previous_stop_key, table, start_key, stop_key, families, filters)
        # Create the scan request object. The last two values are 'Close' and
        # 'Scanner_ID' respectively.
        rq = request.scan_request(
            cur_region, start_key, stop_key, families, filters, False, None)
        try:
            # Send the request.
            response = cur_region.region_client._send_request(rq)
        except PyBaseException as e:
            # Uh oh. Probably a region/region server issue. Handle it and try
            # again.
            e._handle_exception(self, dest_region=cur_region)
            return self._scan_hit_region_once(previous_stop_key, table, start_key, stop_key, families, filters)
        return response, cur_region

    def _scan_region_while_more_results(self, cur_region, response):
        # Create our own intermediate response set.
        response_set = Result(None)
        # Grab the scanner_id from the first_response.
        scanner_id = response.scanner_id
        # We only need to specify the scanner_id here because the region we're
        # pinging remembers our query based on the scanner_id.
        rq = request.scan_request(
            cur_region, None, None, None, None, False, scanner_id)
        while response.more_results_in_region:
            # Repeatedly hit it until empty. Note that we're not handling any
            # exceptions here, instead letting them bubble up because if any
            # of these calls fail we need to rescan the whole region (it seems
            # like a lot of work to search the results for the max row key that
            # we've received so far and rescan from there up)
            response = cur_region.region_client._send_request(rq)
            response_set._append_response(response)
        # Now close the scanner.
        rq = request.scan_request(
            cur_region, None, None, None, None, True, scanner_id)
        _ = cur_region.region_client._send_request(rq)
        # Close it and return the results!
        return response_set

    """
        HERE LAY REGION AND CLIENT DISCOVERY
    """

    def _find_hosting_region(self, table, key):
        # Check if it's in the cache already.
        dest_region = self._get_from_region_cache(table, key)
        if dest_region is None:
            # We have to reach out to master for the results.
            with self._master_lookup_lock:
                # Not ideal that we have to lock every thread however we limit
                # concurrent meta requests to one. This is because of the case
                # where 1000 greenlets all fail simultaneously we don't want
                # 1000 requests shot off to the master (all looking for the
                # same response). My solution is to only let one through at a
                # time and then when it's your turn, check the cache again to
                # see if one of the greenlets let in before you already fetched
                # the meta or not. We can't bucket greenlets and selectively
                # wake them up simply because we have no idea which key falls
                # into which region. We can bucket based on key but that's a
                # lot of overhead for an unlikely scenario.
                dest_region = self._get_from_region_cache(table, key)
                if dest_region is None:
                    # Nope, still not in the cache.
                    logger.debug(
                        'Region cache miss! Table: %s, Key: %s', table, key)
                    # Ask master for region information.
                    dest_region = self._discover_region(table, key)
        return dest_region
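
    # The lookup above is classic double-checked locking: probe the cache,
    # take the lock, probe again (another thread may have filled the cache
    # while we waited), and only then do the expensive master lookup. A
    # generic sketch with hypothetical cache/fetch helpers:
    #
    #   value = cache.get(key)
    #   if value is None:
    #       with lock:
    #           value = cache.get(key)   # re-check under the lock
    #           if value is None:
    #               value = cache[key] = fetch(key)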

    def _discover_region(self, table, key):
        meta_key = self._construct_meta_key(table, key)
        # Create the appropriate meta request given a meta_key.
        meta_rq = request.master_request(meta_key)
        try:
            # This will throw standard Region/RegionServer exceptions.
            # We need to catch them and convert them to the Master equivalent.
            response = self.master_client._send_request(meta_rq)
        except (AttributeError, RegionServerException, RegionException):
            if self.master_client is None:
                # I don't know why this can happen but it does.
                raise MasterServerException(None, None)
            raise MasterServerException(
                self.master_client.host, self.master_client.port)
        # Master gave us a response. We need to run and parse the response,
        # then do all necessary work for entering it into our structures.
        return self._create_new_region(response, table)

    def _create_new_region(self, response, table):
        cells = response.result.cell
        # We have a valid response but no cells? Apparently that means the
        # table doesn't exist!
        if len(cells) == 0:
            raise NoSuchTableException("Table does not exist.")
        # We get ~4 cells back each holding different information. We only care
        # about two of them.
        for cell in cells:
            if cell.qualifier == "regioninfo":
                # Take the regioninfo information and parse it into our own
                # Region representation.
                new_region = region_from_cell(cell)
            elif cell.qualifier == "server":
                # Grab the host, port of the Region Server that this region is
                # hosted on.
                server_loc = cell.value
                host, port = cell.value.split(':')
            else:
                continue
        # Do we have an existing client for this region server already?
        if server_loc in self.reverse_client_cache:
            # If so, grab it!
            new_region.region_client = self.reverse_client_cache[server_loc]
        else:
            # Otherwise we need to create a new region client instance.
            new_client = region.NewClient(host, port, self.pool_size)
            if new_client is None:
                # Welp. We can't connect to the server that the Master
                # supplied. Raise an exception.
                raise RegionServerException(host=host, port=port)
            logger.info("Created new Client for RegionServer %s", server_loc)
            # Add it to the host,port -> instance of region client map.
            self.reverse_client_cache[server_loc] = new_client
            # Attach the region_client to the region.
            new_region.region_client = new_client
        # Region's set up! Add this puppy to the cache so future requests can
        # use it.
        self._add_to_region_cache(new_region)
        logger.info("Successfully discovered new region %s", new_region)
        return new_region

    def _recreate_master_client(self):
        if self.master_client is not None:
            # yep, still no idea why self.master_client can be set to None.
            self.master_client.close()
        # Ask ZooKeeper for the location of the Master.
        ip, port = zk.LocateMaster(self.zkquorum)
        try:
            # Try creating a new client instance and setting it as the new
            # master_client.
            self.master_client = region.NewClient(ip, port, self.pool_size)
        except RegionServerException:
            # We can't connect to the address that ZK supplied. Raise an
            # exception.
            raise MasterServerException(ip, port)

    """
        HERE LAY THE MISCELLANEOUS
    """

    def _close_old_regions(self, overlapping_region_intervals):
        # Loop over the regions to close and close whoever their
        # attached client is.
        #
        # TODO: ...should we really be killing a client unnecessarily?
        for reg in overlapping_region_intervals:
            reg.data.region_client.close()

    def _purge_client(self, region_client):
        # Given a client to close, purge all of its known hosted regions from
        # our cache, delete the reverse lookup entry and close the client,
        # clearing up any file descriptors.
        with self._cache_lock:
            for reg in region_client.regions:
                self._delete_from_region_cache(reg.table, reg.start_key)
            self.reverse_client_cache.pop(
                region_client.host + ":" + region_client.port, None)
            region_client.close()

    def _purge_region(self, reg):
        # Given a region, deletes its entry from the cache and removes it
        # from its region client's region list.
        with self._cache_lock:
            self._delete_from_region_cache(reg.table, reg.start_key)
            try:
                reg.region_client.regions.remove(reg)
            except ValueError:
                pass

    def _construct_meta_key(self, table, key):
        return table + "," + key + ",:"

    def close(self):
        logger.info("Main client received close request.")
        # Close the master client.
        if self.master_client is not None:
            self.master_client.close()
        # Clear the region cache.
        self.region_cache.clear()
        # Close each open region client.
        for location, client in self.reverse_client_cache.items():
            client.close()
        self.reverse_client_cache = {}
Example 5
    def _optimize(self):
        ''' 
            Do a few things:
            1. Remove redundant bugs first.
            2. Remove everything that isn't related to a bug. 
        '''
        bugs = [x for x in self.trace if self._is_bug(x)]

        is_flush = lambda x: x['event'] == 'FLUSH'
        is_store = lambda x: x['event'] == 'STORE'
        is_fence = lambda x: x['event'] == 'FENCE'
        get_timestamp = lambda x: x['timestamp']

        # Step 1: Remove bugs from redundant locations
        new_trace = self._opt_remove_redt_bugs()

        print(f'(Step 1) Optimized from {len(self.trace)} trace events to {len(new_trace)} trace events.')
        self.trace = new_trace

        # Step 2: Remove irrelevant stores.
        '''
            First, we get the addresses of all the reported bugs.
            Then, for stores which don't match the address of a bug, remove
            the store.
        '''
        unique_bug_addrs = IntervalTree()
        for te in self.trace:
            if not self._is_bug(te):
                continue

            a1, a2 = self._get_bug_addresses(te)

            if a1 is not None:
                unique_bug_addrs.addi(a1[0], a1[1], True)
            
            if a2 is not None:
                unique_bug_addrs.addi(a2[0], a2[1], True)

        # Now we have the bug addresses. Walk the trace in reverse and keep
        # only the stores and flushes that overlap one of those addresses.
        new_trace = []
        for te in reversed(self.trace):
            if te['event'] != 'STORE' and te['event'] != 'FLUSH':
                new_trace += [te]
                continue

            addr = (te['address'], te['address'] + te['length'])
            
            if unique_bug_addrs.overlap(*addr):
                new_trace += [te]
                # unique_bug_addrs.remove_overlap(*addr)

        new_trace.reverse()

        print(f'(Step 2) Optimized from {len(self.trace)} trace events to {len(new_trace)} trace events.')
        self.trace = new_trace


        '''
        Now, we want to remove all repeated stores between flushes. Essentially,
        if store X, store X, ... flush X, we only want the most recent store X.
        '''

        in_flight = IntervalTree()
        new_trace = []
        for te in self.trace:
            if not is_flush(te) and not is_store(te):
                new_trace += [te]
                continue

            addr = (te['address'], te['address'] + te['length'])

            if is_flush(te):
                # Get all the stores in the range, remove them, and add them to the
                # new trace
                tes = in_flight[addr[0]:addr[1]]
                if tes:
                    # Make them trace events again
                    new_trace += [x.data for x in tes]
                    in_flight.remove_overlap(addr[0], addr[1])
            
            if is_store(te):
                # Keep only the most recent store to this range: drop any
                # older overlapping stores, then record this one.
                in_flight.remove_overlap(addr[0], addr[1])
                in_flight.addi(addr[0], addr[1], te)

        # Now, I need to add all the things back that were never flushed
        new_trace += [x.data for x in in_flight[:]]

        # Sort the new_trace by timestamp
        new_trace.sort(key=get_timestamp)

        print(f'(Step 3) Optimized from {len(self.trace)} trace events to {len(new_trace)} trace events.')
        self.trace = new_trace

        # Step 4: Remove redundant fences
        new_trace = []
        prev_te = None
        # for te in reversed(self.trace):
        for te in self.trace:
            if te in bugs:
                new_trace += [te]
                continue
            
            if te['event'] == 'FENCE':
                if prev_te is not None and prev_te['event'] == 'FENCE':
                    continue
        
            new_trace += [te]
            prev_te = te

        print(f'(Step 4) Optimized from {len(self.trace)} trace events to {len(new_trace)} trace events.')
        self.trace = new_trace
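
A toy sketch of the overwrite semantics used for the in-flight stores above,
with integer timestamps standing in for trace events (intervaltree only):

from intervaltree import IntervalTree

in_flight = IntervalTree()
for ts, (addr, length) in enumerate([(0, 8), (0, 8), (4, 8)]):
    # Drop any older overlapping store, then record the newest one.
    in_flight.remove_overlap(addr, addr + length)
    in_flight.addi(addr, addr + length, ts)
print(sorted(iv.data for iv in in_flight))  # [2] - only the last store survives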
Esempio n. 6
0
class MainClient:
    def __init__(self, zkquorum, pool_size):
        # Location of the ZooKeeper quorum (csv)
        self.zkquorum = zkquorum
        # Connection pool size per region server (and master!)
        self.pool_size = pool_size
        # Persistent connection to the master server.
        self.master_client = None
        # IntervalTree data structure that allows me to create ranges
        # representing known row keys that fall within a specific region. Any
        # 'region look up' is then O(logn)
        self.region_cache = IntervalTree()
        # Takes a client's host:port as key and maps it to a client instance.
        self.reverse_client_cache = {}
        # Mutex used for all caching operations.
        self._cache_lock = Lock()
        # Mutex used so only one thread can request meta information from
        # the master at a time.
        self._master_lookup_lock = Lock()

    """
        HERE LAY CACHE OPERATIONS
    """

    def _add_to_region_cache(self, new_region):
        stop_key = new_region.stop_key
        if stop_key == '':
            # This is hacky but our interval tree requires hard interval stops.
            # So what's the largest char out there? chr(255) -> '\xff'. If
            # you're using '\xff' as a prefix for your rows then this'll cause
            # a cache miss on every request.
            stop_key = '\xff'
        # Keys are formatted like: 'tablename,key'
        start_key = new_region.table + ',' + new_region.start_key
        stop_key = new_region.table + ',' + stop_key

        # Only let one person touch the cache at once.
        with self._cache_lock:
            # Get all overlapping regions (overlapping == stale)
            overlapping_regions = self.region_cache[start_key:stop_key]
            # Close the overlapping regions.
            self._close_old_regions(overlapping_regions)
            # Remove the overlapping regions.
            self.region_cache.remove_overlap(start_key, stop_key)
            # Insert my region.
            self.region_cache[start_key:stop_key] = new_region
            # Add this region to the region_client's internal
            # list of all the regions it serves.
            new_region.region_client.regions.append(new_region)

    def _get_from_region_cache(self, table, key):
        # Only let one person touch the cache at once.
        with self._cache_lock:
            # We don't care about the last two characters ',:' in the meta_key.
            # 'table,key,:' --> 'table,key'
            meta_key = self._construct_meta_key(table, key)[:-2]
            # Fetch the region that serves this key
            regions = self.region_cache[meta_key]
            try:
                # Returns a set. Pop the element from the set.
                # (there shouldn't be more than 1 elem in the set)
                a = regions.pop()
                return a.data
            except KeyError:
                # Returned set is empty? Cache miss!
                return None

    def _delete_from_region_cache(self, table, start_key):
        # Don't acquire the lock because the calling function should have done
        # so already
        self.region_cache.remove_overlap(table + "," + start_key)

    """
        HERE LAY REQUESTS
    """

    def get(self, table, key, families={}, filters=None):
        """
        get a row or specified cell with optional filter
        :param table: hbase table
        :param key: row key
        :param families: (optional) specifies columns to get,
          e.g., {"columnFamily1":["col1","col2"], "colFamily2": "col3"}
        :param filters: (optional) column filters
        :return: response with cells
        """
        try:
            # Step 0. Set dest_region to None so if an exception is
            # thrown in _find_hosting_region, the exception handling
            # doesn't break trying to reference dest_region.
            dest_region = None
            # Step 1. Figure out where to send it.
            dest_region = self._find_hosting_region(table, key)
            # Step 2. Build the appropriate pb message.
            rq = request.get_request(dest_region, key, families, filters)
            # Step 3. Send the message and twiddle our thumbs.
            response = dest_region.region_client._send_request(rq)
            # Step 4. Success.
            return Result(response)
        except PyBaseException as e:
            # Step X. Houston, we have an error. The cool thing about how
            # this is coded is that exceptions know how to handle themselves.
            # All we need to do is call _handle_exception and everything should
            # be happy! If it cannot handle itself (unrecoverable) then it will
            # re-raise the exception in the handle method and we'll die too.
            #
            # We pass dest_region in because the handling code needs to know
            # which region or region_client it needs to reestablish.
            e._handle_exception(self, dest_region=dest_region)
            # Everything should be dandy now. Repeat the request!
            return self.get(table, key, families=families, filters=filters)

    def put(self, table, key, values):
        return self._mutate(table, key, values, request.put_request)

    def delete(self, table, key, values):
        return self._mutate(table, key, values, request.delete_request)

    def append(self, table, key, values):
        return self._mutate(table, key, values, request.append_request)

    def increment(self, table, key, values):
        return self._mutate(table, key, values, request.increment_request)

    def _mutate(self, table, key, values, rq_type):
        # Same exact methodology as 'get'. Because all mutate requests have
        # equivalent code I've combined them into a single function.
        try:
            dest_region = None
            dest_region = self._find_hosting_region(table, key)
            rq = rq_type(dest_region, key, values)
            response = dest_region.region_client._send_request(rq)
            return Result(response)
        except PyBaseException as e:
            e._handle_exception(self, dest_region=dest_region)
            return self._mutate(table, key, values, rq_type)

    # Scan can get a bit gnarly - be prepared.
    def scan(self,
             table,
             start_key='',
             stop_key=None,
             families={},
             filters=None):
        # We convert the filter immediately such that it doesn't have to be done
        # for every region. However if the filter has already been converted then
        # we can't convert it again. This means that even though we send out N RPCs
        # we only have to package the filter pb type once.
        if filters is not None and type(filters).__name__ != "Filter":
            filters = _to_filter(filters)
        previous_stop_key = start_key
        # Holds the contents of all responses. We return this at the end.
        result_set = Result(None)
        # We're going to need to loop over every relevant region. Break out
        # of this loop once we discover there are no more regions left to scan.
        while True:
            # Finds the first region and sends the initial message to it.
            first_response, cur_region = self._scan_hit_region_once(
                previous_stop_key, table, start_key, stop_key, families,
                filters)
            try:
                # Now we need to keep pinging this region for more results until
                # it has no more results to return. We can change how many rows it
                # returns for each call in the Requests module but I picked a
                # pseudo-arbitrary figure (alright, fine, I stole it from
                # asynchbase)
                #
                # We pass in first_response so it can pull out the scanner_id
                # from the first response.
                second_response = self._scan_region_while_more_results(
                    cur_region, first_response)
            except PyBaseException as e:
                # Something happened to the region/region client in the middle
                # of a scan. We're going to handle it by...
                #
                # Handle the exception.
                e._handle_exception(self, dest_region=cur_region)
                # Recursively scan JUST this range of keys in the region (it could have been split
                # or merged so this recursive call may be scanning multiple regions or only half
                # of one region).
                result_set._append_response(
                    self.scan(table,
                              start_key=previous_stop_key,
                              stop_key=cur_region.stop_key,
                              families=families,
                              filters=filters))
                # We continue here because we don't want to append the
                # first_response results to the result_set. When we did the
                # recursive scan it rescanned whatever the first_response
                # initially contained. Appending both will produce duplicates.
                previous_stop_key = cur_region.stop_key
                if previous_stop_key == '' or (stop_key is not None and
                                               previous_stop_key > stop_key):
                    break
                continue
            # Both calls succeeded! Append the results to the result_set.
            result_set._append_response(first_response)
            result_set._append_response(second_response)
            # Update the new previous_stop_key (so the next iteration can
            # lookup the next region to scan)
            previous_stop_key = cur_region.stop_key
            # Stopping criteria. This region is either the end ('') or the end of this region is
            # beyond the specific stop_key.
            if previous_stop_key == '' or (stop_key is not None
                                           and previous_stop_key > stop_key):
                break
        return result_set

    def _scan_hit_region_once(self, previous_stop_key, table, start_key,
                              stop_key, families, filters):
        try:
            # Lookup the next region to scan by searching for the
            # previous_stop_key (region keys are inclusive on the start and
            # exclusive on the end)
            cur_region = self._find_hosting_region(table, previous_stop_key)
        except PyBaseException as e:
            # This means that either Master is down or something's funky with the META region. Try handling it
            # and recursively perform the same call again.
            e._handle_exception(self)
            return self._scan_hit_region_once(previous_stop_key, table,
                                              start_key, stop_key, families,
                                              filters)
        # Create the scan request object. The last two values are 'Close' and
        # 'Scanner_ID' respectively.
        rq = request.scan_request(cur_region, start_key, stop_key, families,
                                  filters, False, None)
        try:
            # Send the request.
            response = cur_region.region_client._send_request(rq)
        except PyBaseException as e:
            # Uh oh. Probably a region/region server issue. Handle it and try
            # again.
            e._handle_exception(self, dest_region=cur_region)
            return self._scan_hit_region_once(previous_stop_key, table,
                                              start_key, stop_key, families,
                                              filters)
        return response, cur_region

    def _scan_region_while_more_results(self, cur_region, response):
        # Create our own intermediate response set.
        response_set = Result(None)
        # Grab the scanner_id from the first_response.
        scanner_id = response.scanner_id
        # We only need to specify the scanner_id here because the region we're
        # pinging remembers our query based on the scanner_id.
        rq = request.scan_request(cur_region, None, None, None, None, False,
                                  scanner_id)
        while response.more_results_in_region:
            # Repeatedly hit it until empty. Note that we're not handling any
            # exceptions here, instead letting them bubble up because if any
            # of these calls fail we need to rescan the whole region (it seems
            # like a lot of work to search the results for the max row key that
            # we've received so far and rescan from there up)
            response = cur_region.region_client._send_request(rq)
            response_set._append_response(response)
        # Now close the scanner.
        rq = request.scan_request(cur_region, None, None, None, None, True,
                                  scanner_id)
        _ = cur_region.region_client._send_request(rq)
        # Close it and return the results!
        return response_set

    """
        HERE LAY REGION AND CLIENT DISCOVERY
    """

    def _find_hosting_region(self, table, key):
        # Check if it's in the cache already.
        dest_region = self._get_from_region_cache(table, key)
        if dest_region is None:
            # We have to reach out to master for the results.
            with self._master_lookup_lock:
                # Not ideal that we have to lock every thread however we limit
                # concurrent meta requests to one. This is because of the case
                # where 1000 greenlets all fail simultaneously we don't want
                # 1000 requests shot off to the master (all looking for the
                # same response). My solution is to only let one through at a
                # time and then when it's your turn, check the cache again to
                # see if one of the greenlets let in before you already fetched
                # the meta or not. We can't bucket greenlets and selectively
                # wake them up simply because we have no idea which key falls
                # into which region. We can bucket based on key but that's a
                # lot of overhead for an unlikely scenario.
                dest_region = self._get_from_region_cache(table, key)
                if dest_region is None:
                    # Nope, still not in the cache.
                    logger.debug('Region cache miss! Table: %s, Key: %s',
                                 table, key)
                    # Ask master for region information.
                    dest_region = self._discover_region(table, key)
        return dest_region

    def _discover_region(self, table, key):
        meta_key = self._construct_meta_key(table, key)
        # Create the appropriate meta request given a meta_key.
        meta_rq = request.master_request(meta_key)
        try:
            # This will throw standard Region/RegionServer exceptions.
            # We need to catch them and convert them to the Master equivalent.
            response = self.master_client._send_request(meta_rq)
        except (AttributeError, RegionServerException, RegionException):
            if self.master_client is None:
                # I don't know why this can happen but it does.
                raise MasterServerException(None, None)
            raise MasterServerException(self.master_client.host,
                                        self.master_client.port)
        # Master gave us a response. We need to run and parse the response,
        # then do all necessary work for entering it into our structures.
        return self._create_new_region(response, table)

    def _create_new_region(self, response, table):
        cells = response.result.cell
        # We have a valid response but no cells? Apparently that means the
        # table doesn't exist!
        if len(cells) == 0:
            raise NoSuchTableException("Table does not exist.")
        # We get ~4 cells back each holding different information. We only care
        # about two of them.
        for cell in cells:
            if cell.qualifier == "regioninfo":
                # Take the regioninfo information and parse it into our own
                # Region representation.
                new_region = region_from_cell(cell)
            elif cell.qualifier == "server":
                # Grab the host, port of the Region Server that this region is
                # hosted on.
                server_loc = cell.value
                host, port = cell.value.split(':')
            else:
                continue
        # Do we have an existing client for this region server already?
        if server_loc in self.reverse_client_cache:
            # If so, grab it!
            new_region.region_client = self.reverse_client_cache[server_loc]
        else:
            # Otherwise we need to create a new region client instance.
            new_client = region.NewClient(host, port, self.pool_size)
            if new_client is None:
                # Welp. We can't connect to the server that the Master
                # supplied. Raise an exception.
                raise RegionServerException(host=host, port=port)
            logger.info("Created new Client for RegionServer %s", server_loc)
            # Add it to the host,port -> instance of region client map.
            self.reverse_client_cache[server_loc] = new_client
            # Attach the region_client to the region.
            new_region.region_client = new_client
        # Region's set up! Add this puppy to the cache so future requests can
        # use it.
        self._add_to_region_cache(new_region)
        logger.info("Successfully discovered new region %s", new_region)
        return new_region

    def _recreate_master_client(self):
        if self.master_client is not None:
            # yep, still no idea why self.master_client can be set to None.
            self.master_client.close()
        # Ask ZooKeeper for the location of the Master.
        ip, port = zk.LocateMaster(self.zkquorum)
        try:
            # Try creating a new client instance and setting it as the new
            # master_client.
            self.master_client = region.NewClient(ip, port, self.pool_size)
        except RegionServerException:
            # We can't connect to the address that ZK supplied. Raise an
            # exception.
            raise MasterServerException(ip, port)

    """
        HERE LAY THE MISCELLANEOUS
    """

    def _close_old_regions(self, overlapping_region_intervals):
        # Loop over the regions to close and close whoever their
        # attached client is.
        #
        # TODO: ...should we really be killing a client unneccessarily?
        for reg in overlapping_region_intervals:
            reg.data.region_client.close()

    def _purge_client(self, region_client):
        # Given a client to close, purge all of it's known hosted regions from
        # our cache, delete the reverse lookup entry and close the client
        # clearing up any file descriptors.
        with self._cache_lock:
            for reg in region_client.regions:
                self._delete_from_region_cache(reg.table, reg.start_key)
            self.reverse_client_cache.pop(
                region_client.host + ":" + region_client.port, None)
            region_client.close()

    def _purge_region(self, reg):
        # Given a region, delete its entry from the cache and remove it
        # from its region client's region list.
        with self._cache_lock:
            self._delete_from_region_cache(reg.table, reg.start_key)
            try:
                reg.region_client.regions.remove(reg)
            except ValueError:
                pass

    def _construct_meta_key(self, table, key):
        return table + "," + key + ",:"

    def close(self):
        logger.info("Main client received close request.")
        # Close the master client.
        if self.master_client is not None:
            self.master_client.close()
        # Clear the region cache.
        self.region_cache.clear()
        # Close each open region client.
        for location, client in self.reverse_client_cache.items():
            client.close()
        self.reverse_client_cache = {}
Example n. 7

import logging

from intervaltree import IntervalTree

# NOTE: load_loci_info and load_mask are helpers defined earlier in the
# (truncated) script; the snakemake object is injected by the Snakemake
# runtime.
with open(snakemake.input.loci_info) as instream:
    ivtree = load_loci_info(instream)

logging.info(f"Loaded {len(ivtree)} loci")

with open(snakemake.input.mask) as instream:
    mask = load_mask(instream)

logging.info(f"Loaded {len(mask)} mask regions")

masked_tree = IntervalTree(ivtree)

for iv in mask:
    masked_tree.remove_overlap(iv.begin, iv.end)

full_len = 0
for iv in ivtree:
    full_len += iv.length()

masked_len = 0
for iv in masked_tree:
    masked_len += iv.length()

logging.info(
    f"{len(ivtree)-len(masked_tree)} ({1-(len(masked_tree)/len(ivtree)):.2%}) loci removed"
)
logging.info(
    f"{full_len-masked_len}bp ({1-(masked_len/full_len):.2%}) of the genome removed"
)
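
A subtlety worth noting: remove_overlap discards, in its entirety, every
locus interval that touches a mask region, which is why the log lines above
count whole loci removed. If the goal were instead to trim loci down to
their unmasked portions, IntervalTree.chop would be the tool. A small
standalone illustration (not part of the script above):

from intervaltree import IntervalTree

t = IntervalTree.from_tuples([(0, 100), (150, 250)])
t.chop(50, 200)    # trim the span [50, 200) out of every interval
print(sorted(t))   # [Interval(0, 50), Interval(200, 250)]

t2 = IntervalTree.from_tuples([(0, 100), (150, 250)])
t2.remove_overlap(50, 200)  # drops both intervals entirely
print(sorted(t2))           # []
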
def test_all():
    from intervaltree import Interval, IntervalTree
    from pprint import pprint
    from operator import attrgetter
    
    def makeinterval(lst):
        return Interval(
            lst[0], 
            lst[1], 
            "{}-{}".format(*lst)
            )
    
    ivs = list(map(makeinterval, [
        [1,2],
        [4,7],
        [5,9],
        [6,10],
        [8,10],
        [8,15],
        [10,12],
        [12,14],
        [14,15],
        ]))
    t = IntervalTree(ivs)
    t.verify()
    
    def data(s): 
        return set(map(attrgetter('data'), s))
    
    # Query tests
    print('Query tests...')
    assert data(t[4])          == set(['4-7'])
    assert data(t[4:5])        == set(['4-7'])
    assert data(t[4:6])        == set(['4-7', '5-9'])
    assert data(t[9])          == set(['6-10', '8-10', '8-15'])
    assert data(t[15])         == set()
    assert data(t.search(5))   == set(['4-7', '5-9'])
    assert data(t.search(6, 11, strict = True)) == set(['6-10', '8-10'])
    
    print('    passed')
    
    # Membership tests
    print('Membership tests...')
    assert ivs[1] in t
    assert Interval(1,3, '1-3') not in t
    assert t.overlaps(4)
    assert t.overlaps(9)
    assert not t.overlaps(15)
    assert t.overlaps(0,4)
    assert t.overlaps(1,2)
    assert t.overlaps(1,3)
    assert t.overlaps(8,15)
    assert not t.overlaps(15, 16)
    assert not t.overlaps(-1, 0)
    assert not t.overlaps(2,4)
    print('    passed')
    
    # Insertion tests
    print('Insertion tests...')
    t.add( makeinterval([1,2]) )  # adding duplicate should do nothing
    assert data(t[1])        == set(['1-2'])
    
    t[1:2] = '1-2'                # adding duplicate should do nothing
    assert data(t[1])        == set(['1-2'])
    
    t.add(makeinterval([2,4]))
    assert data(t[2])        == set(['2-4'])
    t.verify()
    
    t[13:15] = '13-15'
    assert data(t[14])       == set(['8-15', '13-15', '14-15'])
    t.verify()
    print('    passed')
    
    # Duplication tests
    print('Interval duplication tests...')
    t.add(Interval(14,15,'14-15####'))
    assert data(t[14])        == set(['8-15', '13-15', '14-15', '14-15####'])
    t.verify()
    print('    passed')
    
    # Copying and casting
    print('Tree copying and casting...')
    tcopy = IntervalTree(t)
    tcopy.verify()
    assert t == tcopy
    
    tlist = list(t)
    for iv in tlist:
        assert iv in t
    for iv in t:
        assert iv in tlist
    
    tset = set(t)
    assert tset == t.items()
    print('    passed')
    
    # Deletion tests
    print('Deletion tests...')
    try:
        t.remove(
            Interval(1,3, "Doesn't exist")
            )
    except ValueError:
        pass
    else:
        raise AssertionError("Expected ValueError")
    
    try:
        t.remove(
            Interval(500, 1000, "Doesn't exist")
            )
    except ValueError:
        pass
    else:
        raise AssertionError("Expected ValueError")
    
    orig = t.print_structure(True)
    t.discard( Interval(1,3, "Doesn't exist") )
    t.discard( Interval(500, 1000, "Doesn't exist") )
    
    assert data(t[14])        == set(['8-15', '13-15', '14-15', '14-15####'])
    t.remove( Interval(14,15,'14-15####') )
    assert data(t[14])        == set(['8-15', '13-15', '14-15'])
    t.verify()
    
    assert data(t[2])        == set(['2-4'])
    t.discard( makeinterval([2,4]) )
    assert data(t[2])        == set()
    t.verify()
    
    assert t[14]
    t.remove_overlap(14)
    t.verify()
    assert not t[14]
    
    # Emptying the tree
    #t.print_structure()
    for iv in sorted(iter(t)):
        #print('### Removing '+str(iv)+'... ###')
        t.remove(iv)
        #t.print_structure()
        t.verify()
        #print('')
    assert len(t) == 0
    assert t.is_empty()
    assert not t
    
    t = IntervalTree(ivs)
    #t.print_structure()
    t.remove_overlap(1)
    #t.print_structure()
    t.verify()
    
    t.remove_overlap(8)
    #t.print_structure()    
    print('    passed')
    
    t = IntervalTree(ivs)
    pprint(t)
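    # split_overlaps() slices every interval at each point where another
    # interval begins or ends, so afterwards no interval partially overlaps
    # another; e.g. [4,7) and [5,9) become [4,5), [5,7), [5,7), [7,9),
    # with each piece keeping its original data.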
    t.split_overlaps()
    pprint(t)
    #import cPickle as pickle
    #p = pickle.dumps(t)
    #print(p)
    
Example n. 9
File: cache.py Project: flit/pyOCD
class MemoryCache(object):
    """! @brief Memory cache.
    
    Maintains a cache of target memory. The constructor is passed a backing DebugContext object that
    will be used to fill the cache.
    
    The cache is invalidated whenever the target has run since the last cache operation (based on run
    tokens). If the target is currently running, all accesses cause the cache to be invalidated.
    
    The target's memory map is referenced. All memory accesses must be fully contained within a single
    memory region, or a MemoryAccessError will be raised. However, if an access is outside of all regions,
    the access is passed to the underlying context unmodified. When an access is within a region, that
    region's cacheability flag is honoured.
    """
    
    def __init__(self, context, core):
        self._context = context
        self._core = core
        self._run_token = -1
        self._log = LOG.getChild('memcache')
        self._reset_cache()

    def _reset_cache(self):
        self._cache = IntervalTree()
        self._metrics = CacheMetrics()

    def _check_cache(self):
        """! @brief Invalidates the cache if appropriate."""
        if self._core.is_running():
            self._log.debug("core is running; invalidating cache")
            self._reset_cache()
        elif self._run_token != self._core.run_token:
            self._dump_metrics()
            self._log.debug("out of date run token; invalidating cache")
            self._reset_cache()
            self._run_token = self._core.run_token

    def _get_ranges(self, addr, count):
        """! @brief Splits a memory address range into cached and uncached subranges.
        @return Returns a 2-tuple with the first element being a set of Interval objects for each
          of the cached subranges. The second element is a set of Interval objects for each of the
          non-cached subranges.
        """
        cached = self._cache.overlap(addr, addr + count)
        uncached = {Interval(addr, addr + count)}
        for cachedIv in cached:
            newUncachedSet = set()
            for uncachedIv in uncached:

                # No overlap.
                if cachedIv.end < uncachedIv.begin or cachedIv.begin > uncachedIv.end:
                    newUncachedSet.add(uncachedIv)
                    continue

                # Begin segment.
                if cachedIv.begin - uncachedIv.begin > 0:
                    newUncachedSet.add(Interval(uncachedIv.begin, cachedIv.begin))

                # End segment.
                if uncachedIv.end - cachedIv.end > 0:
                    newUncachedSet.add(Interval(cachedIv.end, uncachedIv.end))
            uncached = newUncachedSet
        return cached, uncached
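
    # To make the splitting concrete, a standalone illustration with
    # hypothetical addresses (not part of pyOCD): two cached blocks with a
    # gap at [0x1010, 0x1020).
    #
    #   from intervaltree import Interval, IntervalTree
    #   cache = IntervalTree([Interval(0x1000, 0x1010), Interval(0x1020, 0x1030)])
    #   cache.overlap(0x1008, 0x1028)
    #   # -> {Interval(4096, 4112), Interval(4128, 4144)}
    #
    # For a read of [0x1008, 0x1028), _get_ranges reports those two
    # intervals as the cached set and the gap Interval(0x1010, 0x1020) as
    # the only uncached subrange.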

    def _read_uncached(self, uncached):
        """! "@brief Reads uncached memory ranges and updates the cache.
        @return A list of Interval objects is returned. Each Interval has its @a data attribute set
          to a bytearray of the data read from target memory.
        """
        uncachedData = []
        for uncachedIv in uncached:
            data = self._context.read_memory_block8(uncachedIv.begin, uncachedIv.end - uncachedIv.begin)
            iv = Interval(uncachedIv.begin, uncachedIv.end, bytearray(data))
            self._cache.add(iv) # TODO merge contiguous cached intervals
            uncachedData.append(iv)
        return uncachedData
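
    # Re the TODO above -- a possible approach (a sketch, not pyOCD code):
    # cached intervals never overlap, so in recent intervaltree releases
    # merge_overlaps(strict=False) coalesces blocks that merely touch,
    # concatenating their payloads in address order:
    #
    #   self._cache.merge_overlaps(
    #       data_reducer=lambda a, b: a + b,  # join adjacent bytearrays
    #       strict=False,                     # also merge touching intervals
    #   )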

    def _update_metrics(self, cached, uncached, addr, size):
        cachedSize = 0
        for iv in cached:
            begin = iv.begin
            end = iv.end
            if iv.begin < addr:
                begin = addr
            if iv.end > addr + size:
                end = addr + size
            cachedSize += end - begin

        uncachedSize = sum((iv.end - iv.begin) for iv in uncached)

        self._metrics.reads += 1
        self._metrics.hits += cachedSize
        self._metrics.misses += uncachedSize

    def _dump_metrics(self):
        if self._metrics.total > 0:
            self._log.debug("%d reads, %d bytes [%d%% hits, %d bytes]; %d bytes written",
                self._metrics.reads, self._metrics.total, self._metrics.percent_hit,
                self._metrics.hits, self._metrics.writes)
        else:
            self._log.debug("no reads")

    def _read(self, addr, size):
        """! @brief Performs a cached read operation of an address range.
        @return A list of Interval objects sorted by address.
        """
        # Get the cached and uncached subranges of the requested read.
        cached, uncached = self._get_ranges(addr, size)
        self._update_metrics(cached, uncached, addr, size)

        # Read any uncached ranges.
        uncachedData = self._read_uncached(uncached)

        # Merge the cached intervals with the data we just read.
        combined = list(cached) + uncachedData
        combined.sort(key=lambda x: x.begin)
        return combined

    def _merge_data(self, combined, addr, size):
        """! @brief Extracts data from the intersection of an address range across a list of interval objects.
        
        The range represented by @a addr and @a size is assumed to overlap the intervals. The first
        and last interval in the list may have ragged edges not fully contained in the address range, in
        which case the correct slice of those intervals is extracted.
        
        @param self
        @param combined List of Interval objects forming a contiguous range. The @a data attribute of
          each interval must be a bytearray.
        @param addr Start address. Must be within the range of the first interval.
        @param size Number of bytes. (@a addr + @a size) must be within the range of the last interval.
        @return A single bytearray object with all data from the intervals that intersects the address
          range.
        """
        result = bytearray()
        resultAppend = bytearray()

        # Check for fully contained subrange.
        if len(combined) and combined[0].begin < addr and combined[0].end > addr + size:
            offset = addr - combined[0].begin
            endOffset = offset + size
            result = combined[0].data[offset:endOffset]
            return result
        
        # Take slice of leading ragged edge.
        if len(combined) and combined[0].begin < addr:
            offset = addr - combined[0].begin
            result += combined[0].data[offset:]
            combined = combined[1:]
        # Take slice of trailing ragged edge.
        if len(combined) and combined[-1].end > addr + size:
            offset = addr + size - combined[-1].begin
            resultAppend = combined[-1].data[:offset]
            combined = combined[:-1]

        # Merge.
        for iv in combined:
            result += iv.data
        result += resultAppend

        return result
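
    # A worked illustration of the ragged-edge handling, with hypothetical
    # values: three contiguous intervals covering [0, 12), queried with
    # addr=2, size=8.
    #
    #   combined = [Interval(0, 4, bytearray(b'abcd')),
    #               Interval(4, 8, bytearray(b'efgh')),
    #               Interval(8, 12, bytearray(b'ijkl'))]
    #
    # The leading edge contributes 'cd', the middle interval 'efgh', and
    # the trailing edge 'ij', so _merge_data returns bytearray(b'cdefghij').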

    def _update_contiguous(self, cached, addr, value):
        size = len(value)
        end = addr + size
        leadBegin = addr
        leadData = bytearray()
        trailData = bytearray()
        trailEnd = end

        if cached[0].begin < addr and cached[0].end > addr:
            offset = addr - cached[0].begin
            leadData = cached[0].data[:offset]
            leadBegin = cached[0].begin
        if cached[-1].begin < end and cached[-1].end > end:
            offset = end - cached[-1].begin
            trailData = cached[-1].data[offset:]
            trailEnd = cached[-1].end

        self._cache.remove_overlap(addr, end)

        data = leadData + value + trailData
        self._cache.addi(leadBegin, trailEnd, data)
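
    # The tree-level effect, on toy values: writing b'XXXX' at address 6
    # into a cache holding [0, 8)->b'AAAAAAAA' and [8, 16)->b'BBBBBBBB'
    # preserves the lead bytes [0, 6) and the trail bytes [10, 16), removes
    # both old intervals, and re-adds a single stitched interval:
    #
    #   t.remove_overlap(6, 10)
    #   t.addi(0, 16, bytearray(b'AAAAAA') + bytearray(b'XXXX')
    #                 + bytearray(b'BBBBBB'))
    #   sorted(t)  # [Interval(0, 16, bytearray(b'AAAAAAXXXXBBBBBB'))]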

    def _check_regions(self, addr, count):
        """! @return A bool indicating whether the given address range is fully contained within
              one known memory region, and that region is cacheable.
        @exception MemoryAccessError Raised if the access is not entirely contained within a single region.
        """
        regions = self._core.memory_map.get_intersecting_regions(addr, length=count)

        # If no regions matched, then allow an uncached operation.
        if len(regions) == 0:
            return False

        # Raise if not fully contained within one region.
        if len(regions) > 1 or not regions[0].contains_range(addr, length=count):
            raise MemoryAccessError("individual memory accesses must not cross memory region boundaries")

        # Otherwise return whether the region is cacheable.
        return regions[0].is_cacheable

    def read_memory(self, addr, transfer_size=32, now=True):
        # TODO use more optimal underlying read_memory call
        if transfer_size == 8:
            data = self.read_memory_block8(addr, 1)[0]
        elif transfer_size == 16:
            data = conversion.byte_list_to_u16le_list(self.read_memory_block8(addr, 2))[0]
        elif transfer_size == 32:
            data = conversion.byte_list_to_u32le_list(self.read_memory_block8(addr, 4))[0]

        if now:
            return data
        else:
            def read_cb():
                return data
            return read_cb

    def read_memory_block8(self, addr, size):
        if size <= 0:
            return []

        self._check_cache()

        # Validate memory regions.
        if not self._check_regions(addr, size):
            self._log.debug("range [%x:%x] is not cacheable", addr, addr+size)
            return self._context.read_memory_block8(addr, size)

        # Get the cached and uncached subranges of the requested read.
        combined = self._read(addr, size)

        # Extract data out of combined intervals.
        result = list(self._merge_data(combined, addr, size))
        assert len(result) == size, "result size ({}) != requested size ({})".format(len(result), size)
        return result

    def read_memory_block32(self, addr, size):
        return conversion.byte_list_to_u32le_list(self.read_memory_block8(addr, size*4))

    def write_memory(self, addr, value, transfer_size=32):
        if transfer_size == 8:
            return self.write_memory_block8(addr, [value])
        elif transfer_size == 16:
            return self.write_memory_block8(addr, conversion.u16le_list_to_byte_list([value]))
        elif transfer_size == 32:
            return self.write_memory_block8(addr, conversion.u32le_list_to_byte_list([value]))

    def write_memory_block8(self, addr, value):
        if len(value) <= 0:
            return

        self._check_cache()

        # Validate memory regions.
        cacheable = self._check_regions(addr, len(value))

        # Write to the target first, so if it fails we don't update the cache.
        result = self._context.write_memory_block8(addr, value)

        if cacheable:
            size = len(value)
            end = addr + size
            cached = sorted(self._cache.overlap(addr, end), key=lambda x:x.begin)
            self._metrics.writes += size

            if len(cached):
                # Write data is entirely within a single cached interval.
                if addr >= cached[0].begin and end <= cached[0].end:
                    beginOffset = addr - cached[0].begin
                    endOffset = beginOffset + size
                    cached[0].data[beginOffset:endOffset] = value

                else:
                    self._update_contiguous(cached, addr, bytearray(value))
            else:
                # No cached data in this range, so just add the entire interval.
                self._cache.addi(addr, end, bytearray(value))

        return result

    def write_memory_block32(self, addr, data):
        return self.write_memory_block8(addr, conversion.u32le_list_to_byte_list(data))

    def invalidate(self):
        self._reset_cache()