コード例 #1
0
def ceil_left(arr_in):
    """Return, for each position, the smallest earlier value that is >= the current value.

    Positions with no such earlier value get -1. Note that -1 is also a legal
    value in ``arr_in``, so the sentinel is ambiguous for inputs containing -1.
    """
    result = [-1] * len(arr_in)
    seen = SortedSet()
    for i, value in enumerate(arr_in):
        # bisect_left returns the index of `value` itself when it was already
        # seen, so no separate membership test is needed.
        pos = seen.bisect_left(value)
        if pos < len(seen):
            result[i] = seen[pos]
        seen.add(value)
    return result
コード例 #2
0
    def avoidFlood(self, rains: List[int]) -> List[int]:
        """Plan which lake to dry on each dry day so that no lake floods.

        ``rains[day] == 0`` is a dry day; any other value is the lake that
        fills that day. Returns -1 for rainy days and the lake dried for dry
        days (1 for unused dry days), or ``[]`` when a flood is unavoidable.
        """
        answer = [-1] * len(rains)
        last_rain = {}
        dry_days = SortedSet()
        for day, lake in enumerate(rains):
            if lake == 0:
                dry_days.add(day)
                continue
            if lake in last_rain:
                # Earliest dry day after the lake was last filled.
                pos = dry_days.bisect_left(last_rain[lake])
                if pos == len(dry_days):
                    return []
                answer[dry_days.pop(pos)] = lake
            last_rain[lake] = day

        for day in dry_days:
            answer[day] = 1
        return answer
コード例 #3
0
    def busiestServers(self, k: int, arrival: List[int],
                       load: List[int]) -> List[int]:
        """Return the ids of the servers that handled the most requests.

        Request ``req`` prefers server ``req % k`` and otherwise takes the next
        free server with a larger id, wrapping around to the smallest free id;
        requests arriving while every server is busy are dropped.
        """
        handled = [0] * k
        free = SortedSet(range(k))
        running = []  # heap of (finish_time, server)
        best = 0

        for req, start in enumerate(arrival):
            # Release every server whose job has finished by this arrival.
            while running and running[0][0] <= start:
                free.add(heapq.heappop(running)[1])

            if not free:
                continue

            pos = free.bisect_left(req % k)
            server = free[pos] if pos < len(free) else free[0]
            handled[server] += 1
            best = max(best, handled[server])
            heapq.heappush(running, (start + load[req], server))
            free.discard(server)

        return [server for server, count in enumerate(handled) if count == best]
コード例 #4
0
    def containsNearbyAlmostDuplicate(self, nums: List[int], k: int,
                                      t: int) -> bool:
        """Return True if some i != j has |i - j| <= k and |nums[i] - nums[j]| <= t.

        A sorted window of the last ``k`` values is kept; for each new value
        only its ceiling and floor inside the window need checking.
        """
        # No pair can qualify when k <= 0 or t < 0. Returning early also keeps
        # the window faithful: with t < 0 duplicate values are merged by the
        # set, and removing nums[i - k] later could raise KeyError.
        if not nums or k <= 0 or t < 0:
            return False

        tree_set = SortedSet()
        for i, num in enumerate(nums):
            ceiling_idx = tree_set.bisect_left(num)  # O(log k): first value >= num
            floor_idx = ceiling_idx - 1              # last value < num, if any

            if ceiling_idx < len(tree_set) and tree_set[ceiling_idx] - num <= t:
                return True

            if floor_idx >= 0 and num - tree_set[floor_idx] <= t:
                return True

            tree_set.add(num)  # O(log k)
            if i >= k:  # keep only the last k values (the window)
                tree_set.remove(nums[i - k])  # O(log k)

        return False
コード例 #5
0
class TwoSum:
    """Stores a set of integers and returns the totals of pairs of its values
    that fall within a specified range.

    """
    def __init__(self, startList=None):
        """Initialize the set of integers, optionally seeded from startList.

        """
        self._intSet = SortedSet()
        if startList is not None:
            self.addIntegerList(startList)

    def addIntegerList(self, intList):
        """Add every integer from intList to the set.

        """
        for i in intList:
            self.addInteger(i)

    def addInteger(self, integer):
        """Add a single integer to the set (an integer already present is ignored).

        """
        self._intSet.add(integer)

    def findSums(self, totalLow, totalHigh, requireUnique=False):
        """Return the set of totals in [totalLow, totalHigh] (both inclusive)
        for which at least one pair of integers x, y in the set exists such
        that x + y == total.

        The bounds are swapped if passed out of order. When requireUnique is
        True, a pair must use two distinct integers, so a total reachable only
        as x + x is excluded.
        """
        if totalLow > totalHigh:
            totalLow, totalHigh = totalHigh, totalLow

        foundTotals = set()
        for x in self._intSet:
            # y must lie in [totalLow - x, totalHigh - x]; iStart is inclusive
            # and iStop exclusive, matching SortedSet.islice().
            iStart = self._intSet.bisect_left(totalLow - x)
            iStop = self._intSet.bisect_right(totalHigh - x)
            if iStart >= iStop:
                continue

            for y in self._intSet.islice(iStart, iStop):
                if requireUnique and y == x:
                    continue
                foundTotals.add(x + y)

        return foundTotals
コード例 #6
0
def find_floor_ceiling(s: SortedSet, num: int):
    """Return ``(floor, ceiling)`` of ``num`` within ``s``.

    floor is the largest value in s that is <= num, ceiling the smallest value
    that is >= num; either is None when no such value exists.
    """
    if num in s:
        return num, num
    pos = s.bisect_left(num)
    floor = None if pos == 0 else s[pos - 1]
    ceiling = None if pos == len(s) else s[pos]
    return floor, ceiling
コード例 #7
0
 def containsNearbyAlmostDuplicate(self, nums: List[int], k: int,
                                   t: int) -> bool:
     """Return True if some i != j has |i - j| <= k and |nums[i] - nums[j]| <= t.

     The window holds the last k values; the first value >= num - t is the
     only candidate that needs checking against num + t.
     """
     # No pair can qualify; returning early also avoids remove() on a value
     # the de-duplicating set may already have dropped (possible when t < 0).
     if k <= 0 or t < 0:
         return False
     s = SortedSet()
     for i, num in enumerate(nums):
         idx = s.bisect_left(num - t)
         if idx < len(s) and s[idx] <= num + t:
             return True
         s.add(num)
         if i >= k:
             s.remove(nums[i - k])
     return False
コード例 #8
0
    def containsNearbyAlmostDuplicate(self, nums: List[int], k: int, t: int) -> bool:
        """Return True if some i != j has |i - j| <= k and |nums[i] - nums[j]| <= t.

        Keeps the last k values in a sorted set and checks the successor and
        predecessor of each new value.
        """
        # No pair can qualify; also keeps duplicates out of the window logic.
        if k <= 0 or t < 0:
            return False

        sorted_set = SortedSet()
        for i, num in enumerate(nums):
            # The set is unchanged between the two checks, so bisect once.
            pos = sorted_set.bisect_left(num)

            # successor: smallest stored value >= num
            if pos < len(sorted_set) and sorted_set[pos] <= num + t:
                return True

            # predecessor: largest stored value < num
            if pos > 0 and num <= sorted_set[pos - 1] + t:
                return True

            sorted_set.add(num)
            if len(sorted_set) > k:
                sorted_set.remove(nums[i - k])

        return False
コード例 #9
0
    def containsNearbyAlmostDuplicate(self, nums: List[int], k: int,
                                      t: int) -> bool:
        """Return True if two values at most k positions apart differ by at most t.

        A sorted window of recent values is searched for the first value that
        is >= current - t; a hit within current + t means a match.
        """
        window = SortedSet()
        for idx, value in enumerate(nums):
            lo = window.bisect_left(value - t)
            found = lo < len(window) and window[lo] <= value + t
            if found:
                return True

            window.add(value)
            if idx >= k:
                window.discard(nums[idx - k])

        return False
コード例 #10
0
    def maxSumRow(self, row, k):
        """Return the largest sum of a contiguous run of ``row`` that is <= k.

        Returns ``float('-inf')`` when no run qualifies. Prefix sums are kept
        in a sorted set; for the current prefix P, the smallest earlier prefix
        >= P - k yields the best run ending here.
        """
        best = float('-inf')
        prefix = 0
        seen = SortedSet([0])

        for value in row:
            prefix += value
            pos = seen.bisect_left(prefix - k)
            if pos != len(seen):
                best = max(best, prefix - seen[pos])
            seen.add(prefix)
        return best
コード例 #11
0
def generateOutput(inputIterators):
    """Merge the records yielded by each iterator in ``inputIterators`` into one SortedSet.

    A record equal to one already stored is merged into the stored record via
    ``merge``. A new record is added only if it comes from a ``CircDatasetIter``
    or its ``group.strand`` is ".". The set and the input iterators are then
    passed to ``mergeUnknownStrands``, and the set is returned.
    """
    # Materialize once: the collection is iterated here and handed on to
    # mergeUnknownStrands afterwards, so a one-shot iterable must not be
    # consumed by the loop below.
    inputIterators = list(inputIterators)
    ss = SortedSet()
    for circIter in inputIterators:
        print(datetime.datetime.now().strftime("%H:%M:%S") + " Merging " +
              circIter.name)
        for circ in circIter:
            pos = ss.bisect_left(circ)
            if pos < len(ss) and circ == ss[pos]:
                ss[pos].merge(circ)
            elif isinstance(circIter,
                            CircDatasetIter) or circ.group.strand == ".":
                ss.add(circ)

        print("(total merged circrnas: %d)" % (len(ss)))
    # Use the parameter: the previously referenced name `circIters` is not
    # defined in this function.
    mergeUnknownStrands(ss, inputIterators)
    return ss
コード例 #12
0
    def containsNearbyAlmostDuplicate(self, nums: List[int], k: int,
                                      t: int) -> bool:
        """Return True if some i != j has |i - j| <= k and |nums[i] - nums[j]| <= t."""
        # No pair can qualify; also keeps duplicates out of the window logic.
        if k <= 0 or t < 0:
            return False

        st = SortedSet()
        for i, num in enumerate(nums):
            # First stored value >= num - t; a match if it is also <= num + t.
            index = st.bisect_left(num - t)
            if index < len(st) and st[index] <= num + t:
                return True

            st.add(num)
            if len(st) > k:
                st.remove(nums[i - k])

        return False
    def containsNearbyAlmostDuplicate(self, nums: List[int], k: int,
                                      t: int) -> bool:
        """Return True if some i != j has |i - j| <= k and |nums[i] - nums[j]| <= t.

        Checks the ceiling and floor of each value within a sorted window of
        the last k values.
        """
        # No pair can qualify; also avoids remove() on de-duplicated values.
        if k <= 0 or t < 0:
            return False

        from sortedcontainers import SortedSet as SS
        ss = SS()
        for i, item in enumerate(nums):

            ID = ss.bisect_left(item)
            if ID < len(ss):
                if abs(item - ss[ID]) <= t:
                    return True
            if ID >= 1:
                if abs(item - ss[ID - 1]) <= t:
                    return True

            ss.add(item)

            if k <= i:
                ss.remove(nums[i - k])

        # Previously fell off the end and returned None instead of False.
        return False
コード例 #14
0
    def closestRoom(self, rooms: List[List[int]],
                    queries: List[List[int]]) -> List[int]:
        """For each query ``[preferred, minSize]``, return the id of a room
        with size >= minSize whose id is closest to ``preferred`` (the smaller
        id on ties), or -1 when no room is large enough.

        Queries are answered offline in increasing minSize order; rooms that
        are too small are dropped from a sorted set of ids as minSize grows.
        Neither ``rooms`` nor ``queries`` is modified, so the method can be
        called repeatedly with the same lists.
        """
        import bisect

        n = len(rooms)
        by_size = sorted(rooms, key=lambda room: room[1])
        sizes = [room[1] for room in by_size]
        ids = SortedSet(room[0] for room in by_size)
        order = sorted(range(len(queries)), key=lambda q: queries[q][1])

        ans = [-1] * len(queries)
        pre = 0
        for q in order:
            preferred, minSize = queries[q]
            # Index of the first room (in size order) that is large enough.
            cur = bisect.bisect_left(sizes, minSize)
            if cur == n:
                continue

            # minSize never decreases across queries, so rooms dropped here
            # never need to come back.
            while pre < cur:
                ids.discard(by_size[pre][0])
                pre += 1

            lt = ids.bisect_left(preferred)
            if lt == len(ids):
                if lt > 0:
                    ans[q] = ids[lt - 1]
            else:
                ans[q] = ids[lt]
                tmp = ids[lt] - preferred
                # Prefer the smaller id when it is at least as close.
                if lt > 0 and preferred - ids[lt - 1] <= tmp:
                    ans[q] = ids[lt - 1]

        return ans
コード例 #15
0
# Interactive (REPL) transcript: input lines are interleaved with the values
# the interpreter echoed back, so this is not meant to be run as a script.
# NOTE(review): `sl` is created earlier in the session (not shown here), and
# `defaultdict` needs `from collections import defaultdict`, also not shown.
10000000
# ^ presumably the echo of an earlier `sl.count('c')` that is not shown
sl[-3:]
['e', 'e', 'e']
# ^ echo of `sl[-3:]`

from sortedcontainers import SortedDict
sd = SortedDict({'c': 3, 'a': 1, 'b': 2})
sd
SortedDict({'a': 1, 'b': 2, 'c': 3})
# ^ echo of `sd`: the keys come back in sorted order, not insertion order
sd.popitem(index=-1)
('c', 3)
# ^ echo of popitem(index=-1): the item with the largest key was removed

from sortedcontainers import SortedSet
ss = SortedSet('abracadabra')
ss
SortedSet(['a', 'b', 'c', 'd', 'r'])
# ^ echo of `ss`: the distinct letters, sorted
ss.bisect_left('c')
2
# ^ echo: 'c' is at index 2 of the sorted set

################################################
name_height = {'simon': 177, 'helen': 166}
name_height.get('dudu', 30)  # key missing, so the default 30 is returned

name_default = defaultdict(list)
name_default[0]  # no error, default is []
# (reading a missing key of a defaultdict also inserts it with that default)
name_default[1].append(4)
name_default[1].append(4)
name_default[1].append(4)
name_default
# -> defaultdict(<class 'list'>, {0: [], 1: [4, 4, 4]})

################################################
# tuple unpacking
コード例 #16
0
def test_bisect():
    """Check bisect_left / bisect / bisect_right on a SortedSet split into small sublists."""
    temp = SortedSet(range(100))
    # The constructor no longer takes a `load` keyword in sortedcontainers 2.x;
    # _reset(7) rebuilds the set with a load factor of 7 instead.
    temp._reset(7)
    assert all(temp.bisect_left(val) == val for val in range(100))
    # `bisect` is an alias of `bisect_right`: for a stored value the insertion
    # point is just after it, i.e. val + 1 here.
    assert all(temp.bisect(val) == (val + 1) for val in range(100))
    assert all(temp.bisect_right(val) == (val + 1) for val in range(100))
コード例 #17
0
# pip install sortedcontainers
# Quick tour of the sortedcontainers types. The "# ->" comments show the
# expected output; they were previously pasted in as live (no-op) statements.

from sortedcontainers import SortedList
sl = SortedList(['e', 'a', 'c', 'd', 'b'])
print(sl)
# -> SortedList(['a', 'b', 'c', 'd', 'e'])
sl *= 10_000_000  # 5 letters x 10,000,000 = 50,000,000 items: slow and memory-hungry
print(sl.count('c'))
print(sl[-3:])
from sortedcontainers import SortedDict
sd = SortedDict({'c': 3, 'a': 1, 'b': 2})
print(sd)
# -> SortedDict({'a': 1, 'b': 2, 'c': 3})
print(sd.popitem(index=-1))


from sortedcontainers import SortedSet
ss = SortedSet('abracadabra')
print(ss)
# -> SortedSet(['a', 'b', 'c', 'd', 'r'])
print(ss.bisect_left('c'))
コード例 #18
0
    def oddEvenJumps(self, arr: List[int]) -> int:
        """Count the starting indices from which the last index is reachable.

        Odd-numbered jumps (1st, 3rd, ...) go to the smallest value >= arr[i]
        at a later index; even-numbered jumps go to the largest value <= arr[i]
        at a later index. Ties go to the smallest such index.

        Time Complexity: O(N log N)
        Space Complexity: O(N)
        """
        n = len(arr)
        if n == 0:
            return 0

        def next_index(order):
            # Visiting indices in `order`, each index's target is the first
            # later-visited index that is also later in the array. The stack
            # holds indices still waiting for their target.
            target = [n] * n
            stack = []
            for i in order:
                while stack and stack[-1] < i:
                    target[stack.pop()] = i
                stack.append(i)
            return target

        # Ascending values (ties by index): the target is the nearest larger-or-equal.
        nearest_larger = next_index(sorted(range(n), key=lambda i: (arr[i], i)))
        # Descending values (ties by index): the target is the nearest smaller-or-equal.
        nearest_smaller = next_index(sorted(range(n), key=lambda i: (-arr[i], i)))

        # Bottom-up DP instead of recursion, so long jump chains cannot exceed
        # Python's recursion limit.
        odd_ok = [False] * n   # reachable when the next jump is odd-numbered
        even_ok = [False] * n  # reachable when the next jump is even-numbered
        odd_ok[-1] = even_ok[-1] = True
        for i in range(n - 2, -1, -1):
            if nearest_larger[i] < n:
                odd_ok[i] = even_ok[nearest_larger[i]]
            if nearest_smaller[i] < n:
                even_ok[i] = odd_ok[nearest_smaller[i]]

        # Every start begins with an odd-numbered jump.
        return sum(odd_ok)
コード例 #19
0
ファイル: PointOut.py プロジェクト: lungnahahd/Python_Prac
# Removing points
## n distinct points on a 2-D plane are given, followed by m queries, each a single number k.
## For each query, among the points whose x value is >= k, find the one with the smallest x and remove it;
## if several points share that smallest x, remove the one with the smallest y.
## For each query, print the removed point (or "-1 -1" when no point qualifies).
### Why a TreeSet: we need the smallest point at or above k while also ordering ties by y.



from sortedcontainers import SortedSet

s = SortedSet()

cmd = list(map(int, input().split()))
n, m = cmd[0], cmd[1]
for _ in range(n):
    x, y = input().split()
    s.add((int(x), int(y)))

for _ in range(m):
    numX = int(input())
    # (numX, -inf) sorts before every point with x == numX regardless of y;
    # the previous key (numX, 0) skipped points with x == numX and y < 0.
    where = s.bisect_left((numX, float('-inf')))
    if where >= len(s):
        print(-1, -1)
    else:
        x, y = s.pop(where)
        print(x, y)
コード例 #20
0
#   http://www.grantjenks.com/docs/sortedcontainers/
# The "# ->" comments show expected output; they were previously pasted in
# as live (no-op) expression statements.
from sortedcontainers import SortedList
sl = SortedList(['e', 'a', 'c', 'd', 'b'])
print(sl)
sl *= 10000000  # 5 letters x 10,000,000 = 50,000,000 items: slow and memory-hungry
print(sl.count('c'))
print(sl[-3:])
# -> ['e', 'e', 'e']

from sortedcontainers import SortedDict
sd = SortedDict({'c': 3, 'a': 1, 'b': 2})
print(sd)
# -> SortedDict({'a': 1, 'b': 2, 'c': 3})
print(sd.popitem())

from sortedcontainers import SortedSet
ss = SortedSet('abracadabra')
print(ss)
print(ss.bisect_left('c'))
コード例 #21
0
### Why a TreeSet: for each new number we must quickly find the closest stored number
### that is far enough above it and the closest one far enough below it.
# Reads "n m", then n numbers; finds the smallest difference >= m between two of them.

import sys
INT_MAX = sys.maxsize
from sortedcontainers import SortedSet

s = SortedSet()  # treeset

cmd = list(map(int, input().split()))
n, m = cmd[0], cmd[1]

result = INT_MAX
for _ in range(n):
    num = int(input())
    # closest stored value >= num + m, and closest stored value <= num - m
    bigIdx = s.bisect_left(num + m)
    smallIdx = s.bisect_right(num - m) - 1
    if bigIdx < len(s):
        result = min(result, s[bigIdx] - num)
    if smallIdx >= 0:
        result = min(result, num - s[smallIdx])
    s.add(num)

# The answer was previously computed but never output. INT_MAX means no pair
# qualified; -1 is reported in that case.
# NOTE(review): confirm -1 is the expected "no answer" output.
print(-1 if result == INT_MAX else result)
コード例 #22
0
        # d = collections.defaultdict(lambda: sys.maxsize)  # A bucket simulated with defaultdict
        # for i, num in enumerate(nums):
        #     key = bucket_key(num)                         # key for current number `num`
        #     for nei in [d[key-1], d[key], d[key+1]]:      # check left bucket, current bucket and right bucket
        #         if abs(nei - num) <= t: return True
        #     d[key] = num    
        #     if i >= k: d.pop(bucket_key(nums[i-k]))       # maintain a size of `k` 
        # return False
        
        #solution by idontknoooo, use SortedSet, time O(Nlog(min(N,k))), space O(min(N,k))
        # Returns True as soon as a value within distance t of `num` is found among the
        # previous k values (the sorted window); otherwise False.
        from sortedcontainers import SortedSet
        # Create SortedSet. `n` is the size of sortedset, max value of `n` is `k` from input
        # NOTE(review): `n` mirrors len(ss) only while no duplicate value is ever added.
        # That holds for t >= 0 (an equal value triggers `return True` first). With t < 0
        # a duplicate add leaves len(ss) < n, and ss[ceiling_idx] could raise IndexError.
        ss, n = SortedSet(), 0                 
        for i, num in enumerate(nums):
            # index whose value is greater than or equal to `num`
            ceiling_idx = ss.bisect_left(num)  
            # index of the largest value smaller than `num` (-1 when there is none)
            floor_idx = ceiling_idx - 1        
            if ceiling_idx < n and abs(ss[ceiling_idx]-num) <= t: # check right neighbour  
                return True  
            if 0 <= floor_idx and abs(ss[floor_idx]-num) <= t: # check left neighbour
                return True
            ss.add(num)
            n += 1
            if i - k >= 0:  # maintain the size of sortedset by finding & removing the earliest number in sortedset
                ss.remove(nums[i-k])
                n -= 1
        return False
    
        # #Right idea but time O(N^2), TLE
        # htb = {}
コード例 #23
0
def test_bisect():
    """bisect_left finds each stored value's own index; bisect and bisect_right land just past it."""
    values = range(100)
    temp = SortedSet(values)
    temp._reset(7)  # rebuild with load factor 7 so the values span several sublists
    assert all(temp.bisect_left(v) == v for v in values)
    assert all(temp.bisect(v) == v + 1 for v in values)
    assert all(temp.bisect_right(v) == v + 1 for v in values)
コード例 #24
0
ファイル: selection.py プロジェクト: Sebelino/pyromhackit
class Selection(IMutableGSlice):
    def __init__(
            self,
            universe: slice,
            revealed: list = None,
            intervals: Iterator = None,
            _length: Optional[int] = None  # For performance
    ):
        """ Build a selection over @universe.

        @intervals, when given, is used directly as the sorted endpoint set. Otherwise @revealed (included ranges)
        is converted to endpoints. With neither, the whole range 0..universe.stop is revealed. @_length, when it
        is an int, is trusted as the revealed count instead of being recomputed.
        """
        self.universe = universe
        if intervals is not None:
            self._intervals = SortedSet(intervals)
        elif revealed is not None:
            self._intervals = self.revealed2sortedset(revealed)
        else:
            self._intervals = self.revealed2sortedset([slice(0, universe.stop)])

        if isinstance(_length, int):
            self._revealed_count = _length
        else:
            self._revealed_count = Selection._compute_len(self._intervals)

    @staticmethod
    def revealed2sortedset(revealed: List[Union[tuple, slice]]) -> SortedSet:
        """ Converts a list of included ranges to the sorted set of their endpoints, in O(n log n), n = len(@revealed).

        Each item is either a (start, stop) pair or a slice, and pairs and slices may be mixed. Both endpoints of
        every item are added to the set, then 0 is removed: an odd-sized set implicitly starts at 0 (see
        sortedset2slices). Because the result is a set, an endpoint shared by two touching ranges such as
        (0, 3), (3, 7) is stored once. NOTE(review): callers presumably pass disjoint, non-touching ranges.
        """
        # 10, [] -> 10, []
        # 10, [(0, 10)] -> 10, [10]
        # 10, [(0, 7)] -> 10, [7]
        # 10, [(7, 10)] -> 10, [7, 10]
        # 10, [(3, 7)] -> 10, [3, 7]
        # 10, [(0, 3), (7, 10)] -> 10, [3, 7, 10]
        # 10, [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)] -> 10, [1, 2, 3, 4, 5, 6, 7, 8, 9]

        intervals = SortedSet()
        # Single pass, item by item: the old whole-list fallback re-iterated @revealed twice (breaking
        # generators) and could not handle a list mixing pairs and slices.
        for item in revealed:
            try:
                start, stop = item
            except TypeError:  # slice objects are not iterable
                start, stop = item.start, item.stop
            intervals.add(start)
            intervals.add(stop)
        intervals.discard(0)
        return intervals

    @staticmethod
    def sortedset2slices(sortedset: SortedSet) -> List[slice]:
        """ Converts a sorted set of endpoints back into the list of included slices.

        Consecutive elements pair up as (start, stop): the first with the second, the third with the fourth, and
        so on. When @sortedset has an odd number of elements, an implicit start of 0 is paired with the first
        element, and the remaining elements pair up from the second one onwards.
        """
        endpoints = list(sortedset)
        if len(endpoints) % 2 == 1:
            endpoints.insert(0, 0)
        return [slice(start, stop) for start, stop in zip(endpoints[::2], endpoints[1::2])]

    def slices(self) -> List[slice]:
        """Return the included ranges as a list of slices."""
        endpoints = self._intervals
        return self.sortedset2slices(endpoints)

    def pairs(self) -> Iterator[Tuple[int, int]]:
        """Return an iterator of (start, stop) pairs for the included ranges.

        An odd number of endpoints means the first range starts at the implicit 0.
        """
        points = self._intervals
        if len(points) % 2 != 0:
            head = [(0, points[0])]
            return itertools.chain(head, zip(points[1::2], points[2::2]))
        return zip(points[::2], points[1::2])

    def gap_pairs(self) -> Iterator[Tuple[int, int]]:
        """Return the (start, stop) pairs of this selection's complement (presumably the hidden ranges)."""
        complement = self.complement()
        return complement.pairs()

    def intervals(self):
        """Return the underlying sorted endpoint set itself (not a copy)."""
        endpoints = self._intervals
        return endpoints

    def exclude(self, from_index: Optional[int], to_index: Optional[int]):
        """ Hide the physical range from @from_index to @to_index and return how many positions became hidden.

        None means the universe bound (universe.start for @from_index, universe.stop for @to_index). Negative
        indices no smaller than -universe.stop are wrapped modulo universe.stop, and a @to_index beyond
        universe.stop is treated as None. Returns 0 without changes when there are no endpoints or when
        @from_index >= @to_index. Both self._intervals and self._revealed_count are updated in place; the return
        value is the drop in self._revealed_count.
        NOTE(review): the range is presumably half-open [from_index, to_index), like the stored slices.
        """
        original_length = self._revealed_count
        # Normalize negative indices; an oversized to_index is retried as None (= universe.stop).
        if isinstance(from_index, int) and -self.universe.stop <= from_index < 0:
            from_index = from_index % self.universe.stop
        if isinstance(to_index, int):
            if to_index > self.universe.stop:
                return self.exclude(from_index, None)
            if -self.universe.stop <= to_index < 0:
                to_index = to_index % self.universe.stop
        # NOTE(review): these asserts are the only bounds validation and are stripped under `python -O`.
        assert from_index is None or self.universe.start <= from_index <= self.universe.stop
        assert to_index is None or self.universe.start <= to_index <= self.universe.stop
        if from_index is None:
            from_index = self.universe.start
        if to_index is None:
            to_index = self.universe.stop
        if len(self._intervals) == 0:
            return 0
        if from_index >= to_index:
            return 0

        # m / n: number of stored endpoints <= from_index / <= to_index.
        m = self._intervals.bisect_right(from_index)
        n = self._intervals.bisect_right(to_index)

        # Position of each index within the endpoint set, or None when it is not itself an endpoint.
        try:
            from_index_index = self._intervals.index(from_index)
        except ValueError:
            from_index_index = None
        try:
            to_index_index = self._intervals.index(to_index)
        except ValueError:
            to_index_index = None
        # Parity rule: an even-sized endpoint set alternates start, stop, start, ...; an odd-sized one has an
        # implicit start at 0 in front. An index lies inside an included range when the count of endpoints
        # <= it is odd (even-sized set) or even (odd-sized set).
        from_index_is_included = (
            len(self._intervals) % 2 == 0 and m % 2 == 1 or len(self._intervals) % 2 == 1 and m % 2 == 0)
        to_index_is_included = (
            len(self._intervals) % 2 == 0 and n % 2 == 1 or len(self._intervals) % 2 == 1 and n % 2 == 0)
        # from_index is the first position of an included range: 0 inside the implicit first range, or an
        # endpoint at a "start" position under the parity rule above.
        from_index_is_leftmost_included = from_index == 0 and from_index_is_included or from_index_index is not None and (
                len(self._intervals) % 2 == 0 and from_index_index % 2 == 0
                or len(self._intervals) % 2 == 1 and (from_index == 0 or from_index_index % 2 == 1))
        # to_index is a stored endpoint at a "stop" position (or is 0 in an odd-sized set).
        # NOTE(review): by the parity rule a stop endpoint makes to_index_is_included False, so the
        # `if to_index_right_of_excluded:` branch below looks unreachable unless 0 is stored -- confirm.
        to_index_right_of_excluded = to_index_index is not None and (
                len(self._intervals) % 2 == 0 and to_index_index % 2 == 1
                or len(self._intervals) % 2 == 1 and (to_index == 0 or to_index_index % 2 == 0))

        # Case analysis on where each end of the range falls; each branch edits the endpoints and subtracts
        # the newly hidden length from _revealed_count.
        if from_index_is_included:
            if from_index_is_leftmost_included:
                if to_index_is_included:
                    if m == 0:
                        to_remove = self._intervals[m:n]
                        endpoint = 0 if n == 0 else self._intervals[n - 1]
                        addendum = 0 if n == 0 else self._intervals[0]
                        self._revealed_count -= (to_index - endpoint) + addendum + sum(
                            b - a for a, b in zip(to_remove[1:-1:2], to_remove[2:-1:2]))
                        del self._intervals[m:n]
                        self._intervals.add(to_index)
                    else:
                        intermediates = self._intervals[m + 1:n - 1]
                        from_start, from_end = self._intervals[m - 1], self._intervals[m]
                        to_start, to_end = self._intervals[n - 1], self._intervals[n]
                        if m == n:
                            self._revealed_count -= to_index - from_start
                            self._intervals.remove(from_start)
                            self._intervals.add(to_index)
                        else:
                            self._revealed_count -= (from_end - from_start) + (to_index - self._intervals[n - 1]) + (
                                from_index - from_start) + sum(
                                b - a for a, b in zip(intermediates[::2], intermediates[1::2]))
                            del self._intervals[m + 1:n - 1]  # intermediates
                            self._intervals.remove(from_start)
                            self._intervals.remove(from_end)
                            self._intervals.remove(to_start)
                            self._intervals.add(to_index)
                else:
                    from_start = 0 if m == 0 else self._intervals[m - 1]
                    from_end = self._intervals[m]
                    self._revealed_count -= from_end - from_start
                    if from_start > 0:
                        self._intervals.remove(from_start)
                    self._intervals.remove(from_end)
            else:
                if to_index_is_included:
                    from_end = self._intervals[m]
                    to_start = self._intervals[n - 1]
                    if m == n:
                        self._revealed_count -= to_index - from_index
                        if from_index > 0:
                            self._intervals.add(from_index)
                        self._intervals.add(to_index)
                    else:
                        intermediates = self._intervals[m + 1:n - 1]
                        self._revealed_count -= (from_end - from_index) + (to_index - to_start) + sum(
                            b - a for a, b in zip(intermediates[::2], intermediates[1::2]))
                        del self._intervals[m + 1:n - 1]  # intermediates
                        if from_index > 0:
                            self._intervals.add(from_index)
                        self._intervals.add(to_index)
                        self._intervals.remove(from_end)
                        self._intervals.remove(to_start)
                else:
                    to_remove = self._intervals[m:n]
                    # NOTE(review): zip(to_remove[1::2], to_remove[::2]) pairs (e1, e0), (e3, e2), ... of an
                    # ascending slice, so each b - a is <= 0; confirm this sign is intended.
                    self._revealed_count -= self._intervals[m] - from_index + sum(b - a for a, b in zip(to_remove[1::2], to_remove[::2]))
                    del self._intervals[m:n]
                    if from_index != 0:
                        self._intervals.add(from_index)
        else:
            if to_index_is_included:
                if to_index_right_of_excluded:
                    to_remove = self._intervals[m:n - 1]
                    del self._intervals[m:n - 1]
                    self._revealed_count -= sum(b - a for a, b in zip(to_remove[::2], to_remove[1::2]))
                else:
                    to_remove = self._intervals[m:n]
                    del self._intervals[m:n]
                    self._intervals.add(to_index)
                    # NOTE(review): same reversed (odd, even) pairing as above; each b - a is <= 0.
                    self._revealed_count -= (to_index - to_remove[0]) + sum(b - a for a, b in zip(to_remove[1::2], to_remove[::2]))
            else:
                to_remove = self._intervals[m:n]
                del self._intervals[m:n]
                self._revealed_count -= sum(b - a for a, b in zip(to_remove[::2], to_remove[1::2]))

        return original_length - self._revealed_count

    def exclude_virtual(self, from_index: Optional[int], to_index: Optional[int]):
        """ Like exclude(), but @from_index and @to_index are virtual positions (counted over revealed items).

        A virtual index that is None or outside [-len(self), len(self)) becomes None (the universe bound); any
        other is mapped with virtual2physical before delegating to exclude().
        """
        def to_physical(index):
            if index is None or not -len(self) <= index < len(self):
                return None
            return self.virtual2physical(index)

        return self.exclude(to_physical(from_index), to_physical(to_index))

    def include(self, from_index: Optional[int], to_index: Optional[int]):
        """ Reveal the physical range from @from_index to @to_index and return how many positions became revealed.

        Index normalization mirrors exclude(): None means the universe bound, negative indices no smaller than
        -universe.stop are wrapped modulo universe.stop, and a @to_index beyond universe.stop is treated as None.
        The endpoint set and self._revealed_count are updated in place; the result is the change in len(self).
        NOTE(review): on an empty endpoint set the early branch runs before the `from_index == to_index` check,
        so include(x, x) with x > 0 stores the single endpoint x (an odd-sized set, i.e. range 0..x under the
        parity rule) while adding 0 to _revealed_count -- confirm. from_index > to_index is not rejected here.
        """
        original_length = len(self)
        # Normalize negative indices; an oversized to_index is retried as None (= universe.stop).
        if isinstance(from_index, int) and -self.universe.stop <= from_index < 0:
            from_index = from_index % self.universe.stop
        if isinstance(to_index, int):
            if to_index > self.universe.stop:
                return self.include(from_index, None)
            if -self.universe.stop <= to_index < 0:
                to_index = to_index % self.universe.stop
        # NOTE(review): these asserts are the only bounds validation and are stripped under `python -O`.
        assert from_index is None or self.universe.start <= from_index <= self.universe.stop
        assert to_index is None or self.universe.start <= to_index <= self.universe.stop
        if from_index is None:
            from_index = self.universe.start
        if to_index is None:
            to_index = self.universe.stop
        if not self._intervals:
            # Nothing stored yet: the new range becomes the only one (0 stays implicit).
            if from_index > 0:
                self._intervals.add(from_index)
            self._intervals.add(to_index)
            self._revealed_count += to_index - from_index
            return to_index - from_index
        if from_index == to_index:
            return 0

        # m / n: number of stored endpoints <= from_index / <= to_index.
        m = self._intervals.bisect_right(from_index)
        n = self._intervals.bisect_right(to_index)

        try:
            from_index_index = self._intervals.index(from_index)
        except ValueError:
            from_index_index = None

        # Parity rule (as in exclude): inside an included range when the count of endpoints <= the index is odd
        # for an even-sized set, or even for an odd-sized set (implicit start at 0).
        from_index_is_included = (
                len(self._intervals) % 2 == 0 and m % 2 == 1 or len(self._intervals) % 2 == 1 and m % 2 == 0)
        to_index_is_included = (
                len(self._intervals) % 2 == 0 and n % 2 == 1 or len(self._intervals) % 2 == 1 and n % 2 == 0)
        # from_index is a stored endpoint at a "stop" position, i.e. it sits right after an included range.
        from_index_right_of_included = from_index_index is not None and (
                len(self._intervals) % 2 == 0 and from_index_index % 2 == 1
                or len(self._intervals) % 2 == 1 and from_index_index % 2 == 0)

        # Case analysis on where each end of the range falls; each branch merges endpoints and adds the newly
        # revealed length to _revealed_count.
        if from_index_is_included:
            if to_index_is_included:
                to_remove = self._intervals[m:n]
                del self._intervals[m:n]
                self._revealed_count += sum(b - a for a, b in zip(to_remove[::2], to_remove[1::2]))
            else:
                to_remove = self._intervals[m:n]
                del self._intervals[m:n]
                self._intervals.add(to_index)
                # NOTE(review): zip(to_remove[1::2], to_remove[::2]) pairs (e1, e0), (e3, e2), ... of an ascending
                # slice, so each b - a is <= 0; confirm this sign is intended.
                self._revealed_count += (to_index - to_remove[-1]) + sum(b - a for a, b in zip(to_remove[1::2], to_remove[::2]))
        else:
            if to_index_is_included:
                if from_index_right_of_included:
                    to_remove = self._intervals[m - 1:n]
                    del self._intervals[m - 1:n]
                    self._revealed_count += sum(b - a for a, b in zip(to_remove[::2], to_remove[1::2]))
                else:
                    to_remove = self._intervals[m:n]
                    del self._intervals[m:n]
                    self._intervals.add(from_index)
                    # NOTE(review): same reversed (odd, even) pairing as above; each b - a is <= 0.
                    self._revealed_count += (to_remove[0] - from_index) + sum(b - a for a, b in zip(to_remove[1::2], to_remove[::2]))
            else:
                if from_index_right_of_included:
                    intermediates = self._intervals[m:n]
                    del self._intervals[m:n]  # intermediates
                    self._intervals.remove(from_index)
                    self._intervals.add(to_index)
                    self._revealed_count += (to_index - from_index) - sum(b - a for a, b in zip(intermediates[::2], intermediates[1::2]))
                else:
                    to_remove = self._intervals[m:n]
                    del self._intervals[m:n]
                    if from_index > 0:
                        self._intervals.add(from_index)
                    self._intervals.add(to_index)
                    self._revealed_count += (to_index - from_index) - sum(b - a for a, b in zip(to_remove[::2], to_remove[1::2]))

        return len(self) - original_length

    def include_partially(self, from_index: Optional[int], to_index: Optional[int], count: Union[int, tuple]):
        if isinstance(count, int):
            return self.include_partially(from_index, to_index, (count, count))
        head_count, tail_count = count
        head_revealed_count = self._include_partially_from_left(from_index, to_index, head_count)
        tail_revealed_count = self._include_partially_from_right(from_index, to_index, tail_count)
        return head_revealed_count + tail_revealed_count

    def _include_partially_from_left(self, from_index: int, to_index: int, count: int):
        if count == 0:
            return 0
        from_index, to_index = self._normalized_range(from_index, to_index)
        subsel = self._spanning_subslice(from_index, to_index).complement().subslice(from_index, to_index)

        revealed_count = 0
        for covered_start, covered_stop in subsel.pairs():
            coverage = covered_stop - covered_start
            if revealed_count + coverage < count:
                self.include(covered_start, covered_stop)
                revealed_count += coverage
            else:
                self.include(covered_start, covered_start + count - revealed_count)
                revealed_count = count
                break
        return revealed_count

    def _include_partially_from_right(self, from_index: int, to_index: int, count: int):
        if count == 0:
            return 0
        from_index, to_index = self._normalized_range(from_index, to_index)
        subsel = self._spanning_subslice(from_index, to_index).complement().subslice(from_index, to_index)

        revealed_count = 0
        for covered_start, covered_stop in reversed(list(subsel.pairs())):
            coverage = covered_stop - covered_start
            if revealed_count + coverage < count:
                self.include(covered_start, covered_stop)
                revealed_count += coverage
            else:
                self.include(covered_stop - (count - revealed_count), covered_stop)
                revealed_count = count
                break
        return revealed_count

    def include_expand(self, from_index: Optional[int], to_index: Optional[int], count: Union[int, Tuple[int, int]]):
        if isinstance(count, int):
            return self.include_expand(from_index, to_index, (count, count))
        if count == (0, 0):
            return 0
        head_count, tail_count = count
        revealed_counter = 0
        gaps = self.complement().subslice(from_index, to_index)
        for a, b in gaps.pairs():
            if b < self.universe.stop:
                revealed_counter += self._include_partially_from_right(a, b, head_count)
            if a > self.universe.start:
                revealed_counter += self._include_partially_from_left(a, b, tail_count)
        return revealed_counter

    def _previous_slice(self, sl: slice):
        """ :return The revealed or covered slice immediately to the left of @sl.
        :raise ValueError if there is none. """
        if sl.start == self.universe.start:
            raise ValueError("There is no slice to the left of {}.".format(sl))
        # TODO O(n) -> O(1)
        zero_or_one = [s for s in self._intervals + self.complement()._intervals if s.stop == sl.start]
        if len(zero_or_one) == 1:
            return zero_or_one[0]
        else:
            raise ValueError("Slice not found: {}.".format(sl))

    def _next_slice(self, sl: slice):
        """ :return The revealed or covered slice immediately to the right of @sl.
        :raise ValueError if there is none. """
        if sl.stop == self.universe.stop:
            raise ValueError("There is no slice to the right of {}.".format(sl))
        # TODO O(n)
        zero_or_one = [s for s in self._intervals + self.complement()._intervals if s.start == sl.stop]
        if len(zero_or_one) == 1:
            return zero_or_one[0]
        else:
            raise ValueError("Slice not found: {}.".format(sl))

    def include_virtual(self, from_index, to_index):
        if from_index is None or from_index < -len(self) or from_index >= len(self):
            p_from_index = None
        else:
            p_from_index = self.virtual2physical(from_index)
        if to_index is None or to_index < -len(self) or to_index >= len(self):
            p_to_index = None
        else:
            p_to_index = self.virtual2physical(to_index)
        return self.include(p_from_index, p_to_index)

    def include_partially_virtual(self, from_index: Optional[int], to_index: Optional[int], count: Union[int, tuple]):
        if from_index is None or from_index < -len(self) or from_index >= len(self):
            p_from_index = None
        else:
            p_from_index = self.virtual2physical(from_index)
        if to_index is None or to_index < -len(self) or to_index >= len(self):
            p_to_index = None
        else:
            p_to_index = self.virtual2physical(to_index)
        return self.include_partially(p_from_index, p_to_index, count)

    # FIXME Inconsistent with reversed(selection). Should probably make this use the default implementation and instead
    # rewrite this one to iter_slices or something.
    def __iter__(self):
        for a, b in self.pairs():
            yield a, b  # FIXME should probably generate slices instead, or every index

    def complement(self):
        if len(self._intervals) >= 1 and self._intervals[-1] == self.universe.stop:
            return Selection(universe=self.universe, intervals=self._intervals[:-1],
                             _length=self.universe.stop - len(self))
        return Selection(universe=self.universe, intervals=self._intervals.union([self.universe.stop]),
                         _length=self.universe.stop - len(self))

    def _normalized_range(self, from_index: Optional[int], to_index: Optional[int]) -> Tuple[int, int]:
        """ For any range [@from_index, @to_index) where the indices are either None or any integer, returns the
        equivalent range [x, y) such that either 0 <= x < y <= upper_bound or x = y = 0. The ranges are equivalent in
        the sense that when using them to slice this selection, they produce the same sub-selection. """
        if from_index is None or from_index <= -self.universe.stop:
            from_index = self.universe.start
        elif from_index > self.universe.stop:
            from_index = self.universe.stop
        elif -self.universe.stop <= from_index < 0:
            from_index = self.universe.stop - from_index

        if to_index is None or to_index >= self.universe.stop:
            to_index = self.universe.stop
        elif -self.universe.stop <= to_index < 0:
            to_index = self.universe.stop - to_index
        elif to_index < -self.universe.stop:
            to_index = self.universe.start

        if from_index >= to_index:
            from_index, to_index = (0, 0)
        return from_index, to_index

    def subslice(self, from_index: Optional[int], to_index: Optional[int]):
        """ :return A new Selection holding exactly the revealed indices of this Selection that lie in
        [@from_index, @to_index), after the bounds are normalized by _normalized_range().

        The revealed runs touching the range are taken from _spanning_subslice(). The first and last of them are then
        clipped to the range, and the result's cached ``_revealed_count`` is decreased by every clipped amount. """
        from_index, to_index = self._normalized_range(from_index, to_index)
        sel = self._spanning_subslice(from_index, to_index)
        if len(sel._intervals) % 2 == 0:
            # Even boundary count: revealed runs are [b0, b1), [b2, b3), ...
            if len(sel) > 0:
                # from_index lies strictly inside the first run: move that run's start boundary to from_index.
                if sel._intervals[0] < from_index < sel._intervals[1]:
                    sel._revealed_count -= from_index - sel._intervals[0]
                    del sel._intervals[0]
                    sel._intervals.add(from_index)
                # to_index lies strictly inside the last run: move that run's stop boundary to to_index.
                if sel._intervals[-2] < to_index < sel._intervals[-1]:
                    sel._revealed_count -= sel._intervals[-1] - to_index
                    del sel._intervals[-1]
                    sel._intervals.add(to_index)
        else:
            # Odd boundary count: the first run is implicitly [0, b0). Adding from_index as a boundary turns it into
            # [from_index, b0) and makes the boundary count even.
            if 0 < from_index < sel._intervals[0]:
                sel._revealed_count -= from_index
                sel._intervals.add(from_index)
            # to_index lies strictly inside the last run (which is [0, b0) when it is the only one):
            # move that run's stop boundary to to_index.
            if (len(sel._intervals) == 1 and to_index < sel._intervals[-1]
                    or len(sel._intervals) >= 2 and sel._intervals[-2] < to_index < sel._intervals[-1]):
                sel._revealed_count -= sel._intervals[-1] - to_index
                del sel._intervals[-1]
                sel._intervals.add(to_index)
        return sel

    def _spanning_subslice(self, from_index: int, to_index: int):
        """ :return A Selection whose set of revealed slices is a subset of that of this Selection such that every index
        in [from_index, to_index) is either on some slice in the subset, or on a gap.

        Boundaries are copied unmodified, so runs that only partly overlap the range are kept whole and the result may
        extend beyond [from_index, to_index); subslice() performs the clipping. An empty selection is returned when
        from_index >= to_index. """
        if from_index >= to_index:
            return Selection(universe=deepcopy(self.universe), intervals=[])
        # m = number of boundaries <= from_index; its parity tells whether from_index is inside a revealed run.
        m = self._intervals.bisect_right(from_index)
        if len(self._intervals) % 2 == 0:
            # Even boundary count: runs are [b0, b1), [b2, b3), ...
            # Odd m: from_index is inside the run starting at boundary m - 1, so start the slice there.
            # Odd n (boundaries < to_index): the run starting at boundary n - 1 overlaps the range, so keep its stop b_n.
            n = self._intervals.bisect_left(to_index)
            intervals = self._intervals[m - (m % 2):n + (n % 2)]
        else:
            # Odd boundary count: the first run is implicitly [0, b0), then [b1, b2), ...; boundary parities are
            # shifted by one relative to the even case, and the lower index is clamped at 0 for the implicit run.
            n = self._intervals.bisect_right(to_index)
            a = max(0, m - ((m + 1) % 2))
            b = n + ((n + 1) % 2)
            intervals = self._intervals[a:b]
        sel = Selection(universe=deepcopy(self.universe), intervals=intervals)
        return sel

    def _slow_subslice(self, from_index: Optional[int], to_index: Optional[int]):
        sel = self.deepcopy()
        if isinstance(from_index, int):
            sel.exclude(None, from_index)
        if isinstance(to_index, int):
            sel.exclude(to_index, None)
        return sel

    def _interval_index(self, pindex):
        """ Binary-search ``self._intervals`` for the element whose [start, stop) range contains @pindex.

        :return The index of that element.
        :raise IndexError if no element contains @pindex.

        NOTE(review): this reads ``.start``/``.stop`` from the elements of ``self._intervals``. Elsewhere in this class
        ``_intervals`` is used as a sorted set of integer boundaries (``bisect_right``, ``add``, arithmetic on its
        elements), on which this would raise AttributeError. It looks left over from an older slice-based
        representation — verify whether it is still used. """
        lower = 0
        upper = len(self._intervals) - 1
        while lower <= upper:
            middle = (lower + upper) // 2
            midsl = self._intervals[middle]
            if pindex < midsl.start:
                upper = middle - 1
            elif midsl.stop <= pindex:
                lower = middle + 1
            else:  # midsl.start <= pindex < midsl.stop:
                return middle
        raise IndexError("{} is not in any interval.".format(pindex))

    def select(self, listlike):
        # TODO only works for stringlike objects
        lst = []
        for interval in self.slices():
            lst.append(listlike[interval])
        selection = listlike[0:0].join(lst)
        return selection

    def physical2virtual(self, pindex: int):
        vindex = 0
        for a, b in self.pairs():
            if a <= pindex < b:
                vindex += pindex - a
                return vindex
            vindex += b - a
        raise IndexError("Physical index {} out of bounds for selection {}".format(pindex, self))

    # TODO: O(n) -> O(log(n)) (using another sorted set for cumulative lengths?)
    def virtual2physical(self, vindex: int):  # TODO -> virtualint2physical
        """ :return the integer n such that where the @vindex'th revealed element is the nth element. If
        @vindex < 0, @vindex is interpreted as (number of revealed elements) + @vindex.
        """
        if vindex < -len(self):
            raise IndexError(
                "Got index {}, expected it to be within range [{},{})".format(vindex, -len(self), len(self)))
        elif vindex < 0:
            return self.virtual2physical(len(self) + vindex)
        cumlength = 0
        for a, b in self.pairs():
            cumlength += b - a
            if vindex < cumlength:
                pindex = b - (cumlength - vindex)
                if a <= pindex < b:
                    return pindex
                else:
                    break
        raise IndexError("Virtual index {} out of bounds for selection {}".format(vindex, self))

    def virtual2physicalselection(self, vslice: slice) -> 'Selection':  # TODO -> virtualslice2physical
        """ :return the sub-Selection that is the intersection of this selection and @vslice. """
        if not self._intervals or vslice.stop == 0:
            return Selection(self.universe, revealed=[])
        if vslice.start is None:
            a = self.virtual2physical(0)
        elif -len(self) <= vslice.start < len(self):
            a = self.virtual2physical(vslice.start)
        elif vslice.start >= len(self):
            a = self._intervals[-1]
        else:
            raise ValueError("Unexpected slice start: {}".format(vslice))
        if vslice.stop is None or vslice.stop >= len(self):
            b = self._intervals[-1] - 1
        elif -len(self) <= vslice.stop < len(self):
            b = self.virtual2physical(vslice.stop - 1)
        else:
            raise ValueError("Unexpected slice stop: {}".format(vslice))
        # INV: a is the physical index of the first element, b is the physical index of the last element
        if b < a:
            return Selection(universe=self.universe, revealed=[])
        m = self._intervals.bisect_right(a)
        n = self._intervals.bisect_right(b)
        intervals = SortedSet([a] + self._intervals[m:n] + [b + 1])
        return Selection(universe=self.universe, intervals=intervals)

    def virtualselection2physical(self, vselection: 'Selection'):  # TODO -> virtualslice2physical
        """ :return the sub-Selection that is the intersection of this selection and @vselection. """
        intervals = []
        for start, stop in vselection:
            for a, b in self.virtual2physicalselection(slice(start, stop)):
                intervals.append(slice(a, b))
        return Selection(universe=self.universe, revealed=intervals)

    def stretched(self, from_index: Optional[int], to_index: Optional[int]):  # TODO remove?
        """ :return A potentially shrinked deep copy of this selection, delimited by the universe
        [@from_index, @to_index). """
        m = self._intervals.bisect_right(from_index)
        n = self._intervals.bisect_right(to_index)
        intervals = self._intervals[m:n]
        return Selection(universe=slice(from_index, to_index), intervals=intervals)

    def __getitem__(self, item):
        return self.virtual2physical(item)

    @staticmethod
    def _compute_len(sortedset: SortedSet):
        """ :return The sum of the lengths of every slice in @slicelist. """
        if len(sortedset) == 0:
            return 0
        elif len(sortedset) % 2 == 0:
            return sum(sortedset[i + 1] - sortedset[i] for i in range(0, len(sortedset), 2))
        return sortedset[0] + sum(sortedset[i + 1] - sortedset[i] for i in range(1, len(sortedset), 2))

    def __len__(self):
        return self._revealed_count

    def __eq__(self, other):
        return repr(self) == repr(other)

    def __mul__(self, other: int):
        if other == 0:
            return Selection(universe=slice(0, 0), revealed=[])
        scaled_universe = slice(self.universe.start * other, self.universe.stop * other)
        scaled_revealed = [other * x for x in self._intervals]
        return Selection(universe=scaled_universe, intervals=scaled_revealed)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __repr__(self):
        return "{}(universe={}, intervals={})".format(self.__class__.__name__, self.universe, self._intervals)

    def __str__(self):
        return repr(self)

    def deepcopy(self):
        """ :return A deep copy of this object. """
        return Selection(universe=deepcopy(self.universe), intervals=deepcopy(self._intervals))
コード例 #25
0
class MutableIntervalDict(
        # pylint: disable=unsubscriptable-object
        IntervalDict[atomic.TO, V],
        MutableMapping[atomic.Interval[atomic.TO], V],
):
    """
    Mutable Interval Dictionary class.

    The :class:`MutableIntervalDict` class (which inherits from the
    :class:`IntervalDict` class) is designed to hold a mutable dict of disjoint sorted
    intervals.

    Note
    ----
        Let :math:`n` (or :math:`n_0`) be the number of intervals of the *self* variable
        and :math:`m` the number of intervals in the *other* variable. Let
        :math:`n_1, ... n_k` be the number of intervals for methods with multiple
        arguments.

        The complexity in time of methods is:

        =========================  ====================================================
        Methods                    Average case
        =========================  ====================================================
        :meth:`__setitem__`        :math:`O(n)`
        :meth:`__delitem__`        :math:`O(n)`
        :meth:`__ior__`            :math:`O(m\\log(n+m))`
        :meth:`update`             :math:`O((\\sum_{i=1}^kn_i)\\log(\\sum_{i=0}^kn_i))`
        :meth:`clear`              :math:`O(1)`
        =========================  ====================================================
    """

    __slots__ = ("_default", "_operator", "_strict")

    def __init__(
        self,
        iterable: Optional[Union[IntervalDict[atomic.TO, V], Mapping[
            atomic.IntervalValue[atomic.TO],
            V], Iterable[Tuple[atomic.IntervalValue[atomic.TO], V]], ]] = None,
        default: Optional[Callable[[], V]] = None,
        operator: Optional[Callable[[V, V], V]] = None,
        strict: Optional[bool] = True,
    ) -> None:
        """
        Initialize a :class:`MutableIntervalDict` instance.

        Arguments
        ---------
            iterable: :class:`Iterable <python:typing.Iterable>`
                An optional iterable that can be converted to a dictionary of (
                interval, value).

        Keyword arguments
        -----------------
            default: :class:`Callable[[], V] <python:typing.Callable>`, optional
                The default factory.
            operator: :class:`Callable[[V, V], V] <python:typing.Callable>`
                The operator function.
            strict: bool
                :data:`False <python:False>` if ``operator`` is a commutative and
                associative law on ``V``.

        Note
        ----

            If  ``operator`` is a commutative and associative law on ``V``,
            the complexity in time is much faster if ``strict`` is set to
            :data:`False <python:False>`.

        Examples
        --------

            >>> from part import MutableIntervalDict
            >>> a = MutableIntervalDict[int, set](
            ...     operator=lambda x, y: x | y,
            ...     strict=False
            ... )
            >>> a.update({(1, 10): {1}})
            >>> print(a)
            {'[1;10)': {1}}
            >>> a.update({(5, 20): {2}})
            >>> print(a)
            {'[1;5)': {1}, '[5;10)': {1, 2}, '[10;20)': {2}}
            >>> a.update({(10, 30): {1}})
            >>> print(a)
            {'[1;5)': {1}, '[5;10)': {1, 2}, '[10;20)': {1, 2}, '[20;30)': {1}}
            >>> print(a.compress())
            {'[1;5)': {1}, '[5;20)': {1, 2}, '[20;30)': {1}}
        """
        super().__init__()
        self._default = default
        self._operator = operator
        self._strict = strict
        # An existing IntervalDict is copied storage-to-storage; any other input is fed through update(),
        # one single-entry mapping per (key, value) pair.
        if isinstance(iterable, IntervalDict):
            self._mapping = iterable._mapping.copy()
            self._intervals = SortedSet(iterable._intervals)
        else:
            self._mapping = {}
            self._intervals: SortedSet = SortedSet()
            if isinstance(iterable, dict):
                self.update(*({key: value} for key, value in iterable.items()))
            elif iterable is not None:
                self.update(*({
                    key: value
                } for key, value in iterable))  # type: ignore

    def __getitem__(self, key: Union[slice,
                                     atomic.IntervalValue[atomic.TO]]) -> V:
        """
        Return a value using either a slice or an interval value.

        If the lookup raises :exc:`KeyError` and a ``default`` factory was
        given, a fresh default value is stored under *key* and returned.

        Arguments
        ---------
            key: Union[IntervalValue, slice]
                The interval requested.

        Returns
        -------
            The found value

        Raises
        ------
            KeyError
                If the *key* is out of range.
        """
        try:
            return super().__getitem__(key)
        except KeyError:
            if self._default is not None:
                value = self._default()
                self[key] = value
                return value
            raise

    def __setitem__(self, key: Union[slice, atomic.IntervalValue[atomic.TO]],
                    value: V) -> None:
        """
        Set a value using either a slice or an interval value.

        Arguments
        ---------
            key: Union[IntervalValue, slice]
                The interval requested.

        Raises
        ------
            KeyError
                If the *key* is out of range.

        Examples
        --------

            >>> from part import MutableIntervalDict
            >>> a = MutableIntervalDict[int, int](
            ...     {(10, 15): 1, (20, 25): 2, (30, 35): 3}
            ... )
            >>> print(a)
            {'[10;15)': 1, '[20;25)': 2, '[30;35)': 3}
            >>> a[12] = 4
            >>> print(a)
            {'[10;12)': 1, '[12;12]': 4, '(12;15)': 1, '[20;25)': 2, '[30;35)': 3}
            >>> a[13:31] = 5
            >>> print(a)
            {'[10;12)': 1, '[12;12]': 4, '(12;13)': 1, '[13;31)': 5, '[31;35)': 3}
            >>> a[:] = 0
            >>> print(a)
            {'(-inf;+inf)': 0}
        """
        # Clear whatever currently overlaps the key, then store the new value on the whole interval.
        interval = self._remove(key)
        if interval:
            self._intervals.add(interval)
            self._mapping[interval] = value

    def __delitem__(
            self, key: Union[slice, atomic.IntervalValue[atomic.TO]]) -> None:
        """
        Delete a value using either a slice or an interval value.

        Arguments
        ---------
            key: Union[IntervalValue, slice]
                The interval requested.

        Raises
        ------
            KeyError
                If the *key* is out of range.

        Examples
        --------

            >>> from part import MutableIntervalDict
            >>> a = MutableIntervalDict[int, int](
            ...     {(10, 15): 1, (20, 25): 2, (30, 35): 3}
            ... )
            >>> print(a)
            {'[10;15)': 1, '[20;25)': 2, '[30;35)': 3}
            >>> del a[12]
            >>> print(a)
            {'[10;12)': 1, '(12;15)': 1, '[20;25)': 2, '[30;35)': 3}
            >>> del a[13:31]
            >>> print(a)
            {'[10;12)': 1, '(12;13)': 1, '[31;35)': 3}
            >>> del a[:]
            >>> print(a)
            {}
        """
        self._remove(key)

    def _remove(self, key):
        """
        Remove whatever is stored on the interval denoted by *key*.

        :meth:`_start` and :meth:`_stop` first split the stored intervals that
        only partly overlap it (re-inserting the parts lying outside); every
        stored interval between the two returned indices is then deleted from
        both ``_mapping`` and ``_intervals``.

        Returns the interval built from *key*. It may be falsy, in which case
        nothing is removed.
        """
        # pylint: disable=protected-access
        interval = IntervalDict._interval(key)
        if interval:
            start = self._start(interval)
            stop = self._stop(interval)
            for index in range(start, stop):
                del self._mapping[self._intervals[index]]
            del self._intervals[start:stop]
        return interval

    def _add(self, interval, value):
        """
        Merge *value* into the dict on *interval*.

        Without an ``operator``, *value* simply overwrites *interval*. With an
        ``operator``, each stored part overlapping *interval* becomes
        ``operator(old_value, value)``, and the parts of *interval* not covered
        by any stored interval receive *value*.
        """
        if self._operator is None:
            self[interval] = value
        else:
            intervals = list(self.select(interval, strict=False))
            for another in ((interval & found)[0] for found in intervals):
                self[another] = self._operator(self[another], value)
            # Parts of the new interval that no existing interval touched.
            for another in sets.FrozenIntervalSet[atomic.TO](
                [interval]) - sets.FrozenIntervalSet[atomic.TO](intervals):
                self[another] = value

    def __or__(self, other) -> "MutableIntervalDict[atomic.TO, V]":
        """
        Construct a new dictionary using self and the *other*.

        Arguments
        ---------
            other: :class:`IntervalDict`
                Another interval dict.

        Returns
        -------
            :class:`MutableIntervalDict`
                The new :class:`IntervalDict`.

        Examples
        --------

            >>> from part import MutableIntervalDict
            >>> a = MutableIntervalDict[int, int](
            ...     {(10, 15): 1, (20, 25): 2, (30, 35): 3},
            ...     operator=lambda x, y: x + y,
            ...     strict=False
            ... )
            >>> print(a | FrozenIntervalDict[int, int]({(15, 22): 4}))
            {'[10;15)': 1, '[15;20)': 4, '[20;22)': 6, '[22;25)': 2, '[30;35)': 3}
        """
        if not isinstance(other, IntervalDict):
            return NotImplemented
        # Copy self (with the same default/operator/strict settings) and merge *other* into the copy.
        result = self.__class__(self,
                                default=self._default,
                                operator=self._operator,
                                strict=self._strict)
        result.update(other)
        return result

    def __ior__(self, other) -> "MutableIntervalDict[atomic.TO, V]":
        """
        Update self with the *other*.

        Arguments
        ---------
            other: :class:`IntervalDict`
                Another interval dict.

        Returns
        -------
            :class:`MutableIntervalDict`
                The updated :class:`MutableIntervalDict`.

        Examples
        --------

            >>> from part import MutableIntervalDict
            >>> a = MutableIntervalDict[int, int](
            ...     operator=lambda x, y: x + y,
            ...     strict=False
            ... )
            >>> a |= MutableIntervalDict[int, int]({(1, 10): 1})
            >>> print(a)
            {'[1;10)': 1}
            >>> a |= MutableIntervalDict[int, int]({(5, 20): 2})
            >>> print(a)
            {'[1;5)': 1, '[5;10)': 3, '[10;20)': 2}
            >>> a |= MutableIntervalDict[int, int]({(10, 30): 3})
            >>> print(a)
            {'[1;5)': 1, '[5;10)': 3, '[10;20)': 5, '[20;30)': 3}
        """
        if not isinstance(other, IntervalDict):
            return NotImplemented
        self.update(other)
        return self

    # pylint: disable=arguments-differ,signature-differs
    def update(  # type: ignore
        self,
        *args: Union[IntervalDict, Mapping[atomic.IntervalValue[atomic.TO], V],
                     Iterable[Tuple[atomic.IntervalValue[atomic.TO], V]], ],
    ) -> None:
        """
        Update the dict.

        Arguments
        ---------
            *args: :class:`Iterable <python:typing.Iterable>`
                An iterable of :class:`IntervalDict` or valid iterable for an interval
                dictionary creation.

        Raises
        ------
            TypeError
                if an argument is not iterable.

        Note
        ----
            If the parameter ``strict`` used in the constructor is :data:`False
            <python:False>`, the complexity is in :math:`O(n\\,log(n)\\,k\\,\\lambda)`
            where:

            * :math:`n` is the length of ``*args``;
            * :math:`k` is the number of output intervals;
            * :math:`\\lambda` is the cost of the ``operator`` parameter used in
              the constructor.

        Examples
        --------

            >>> from part import MutableIntervalDict
            >>> from operator import add
            >>> a = MutableIntervalDict[int, int](
            ...     operator=add,
            ...     default=lambda: 0,
            ... )
            >>> a.update({(1, 10): 1})
            >>> print(a)
            {'[1;10)': 1}
            >>> a.update(
            ...     FrozenIntervalDict[int, int]({(5, 20): 2}),
            ...     FrozenIntervalDict[int, int]({(10, 30): 3})
            ... )
            >>> print(a)
            {'[1;5)': 1, '[5;10)': 3, '[10;20)': 5, '[20;30)': 3}
            >>> a = MutableIntervalDict[int, set](
            ...     operator=lambda x, y: x | y,
            ...     strict=False
            ... )
            >>> a.update({(1, 10): {1}})
            >>> print(a)
            {'[1;10)': {1}}
            >>> a.update(
            ...     FrozenIntervalDict[int, set]({(5, 20): {2}}),
            ...     FrozenIntervalDict[int, set]({(10, 30): {3}})
            ... )
            >>> print(a)
            {'[1;5)': {1}, '[5;10)': {1, 2}, '[10;20)': {2, 3}, '[20;30)': {3}}
        """
        # TODO determine complexity
        strict = self._strict
        operator = self._operator
        # The single-sweep merge reorders operator applications, so it is only used when
        # strict=False (operator declared commutative and associative) and an operator exists.
        if strict or operator is None:
            self._strict_update(*args)
        else:
            self._enhanced_update(*args)

    def _strict_update(
        self,
        *args: Union[IntervalDict, Mapping[atomic.IntervalValue[atomic.TO], V],
                     Iterable[Tuple[atomic.IntervalValue[atomic.TO], V]], ],
    ) -> None:
        """
        Apply every (key, value) pair of *args*, in order, through :meth:`_add`.

        Mappings are read with ``items()``; other iterables must yield
        (key, value) pairs. Raises :exc:`TypeError` for a non-iterable argument.
        """
        for other in args:
            if isinstance(other, collections.abc.Mapping):
                for key, value in other.items():
                    self._add(atomic.Atomic.from_value(key), value)
            elif isinstance(other, collections.abc.Iterable):
                for key, value in other:  # type: ignore
                    self._add(atomic.Atomic.from_value(key), value)
            else:
                raise TypeError(f"{type(other)} object is not iterable")

    # pylint: disable=protected-access
    @classmethod
    def _rest(cls, rest, cursors, index, element):
        """
        If dict number *index* (*element*) still has an interval at its cursor,
        add ``(lower, upper, index, value)`` for that interval to the sorted list
        *rest*.
        """
        if cursors[index] < len(element):
            interval = element._intervals[cursors[index]]
            value = element._mapping[interval]
            rest.add((interval.lower, interval.upper, index, value))

    # pylint: disable=protected-access
    def _create(self, *args):
        """
        Build the sweep state used by :meth:`_enhanced_update`.

        Returns ``(elements, cursors, current, rest)``: the non-empty dicts to
        merge (self first; other arguments are wrapped in FrozenIntervalDict),
        one cursor per dict (all 0), an empty ``current`` sorted list, and a
        ``rest`` sorted list seeded with the first interval of each dict.
        """
        # Create a list of non empty IntervalDict
        elements = []
        for element in itertools.chain([self], args):
            if not isinstance(element, IntervalDict):
                element = FrozenIntervalDict(element)
            if element:
                elements.append(element)

        cursors = [0] * len(elements)

        current = SortedList()

        rest = SortedList()
        for index, element in enumerate(elements):
            self._rest(rest, cursors, index, element)

        return (elements, cursors, current, rest)

    # pylint: disable=too-many-arguments,protected-access
    @classmethod
    def _next(cls, upper, elements, cursors, current, rest):
        """
        Advance the sweep past the bound *upper* and return the next ``(lower, upper)``.

        ``current`` holds the intervals active in the segment being built, stored
        as ``(upper, lower, index, value)`` so that it is ordered by upper bound.
        ``rest`` holds the next pending interval of each dict as
        ``(lower, upper, index, value)``, ordered by lower bound. The returned
        segment ends at the smallest active upper bound, or at ``prev()`` of the
        next pending lower bound if that comes first.
        """

        # Remove useless elements from current
        # (intervals ending exactly at *upper* are finished: advance their dict's
        # cursor and queue that dict's next interval in *rest*).
        while current and current[0][0] == upper:
            (_, _, index, _) = current[0]
            del current[0]
            cursors[index] += 1
            element = elements[index]
            cls._rest(rest, cursors, index, element)

        # Continue right after *upper* while intervals are still active (or nothing is
        # pending); otherwise skip the uncovered gap and start at the next pending interval.
        if current or not rest:
            lower = upper.next()
        else:
            lower = rest[0][0]

        # Move elements from rest to current
        while rest and rest[0][0] == lower:
            (lower, upper, index, value) = rest[0]
            del rest[0]
            current.add((upper, lower, index, value))

        if current:
            upper = current[0][0]
        if rest:
            upper = min(upper, rest[0][0].prev())

        return (lower, upper)

    def _enhanced_update(
        self,
        *args: Union[IntervalDict, Mapping[atomic.IntervalValue[atomic.TO], V],
                     Iterable[Tuple[atomic.IntervalValue[atomic.TO], V]], ],
    ) -> None:
        """
        Rebuild the whole dict in a single sweep over self and *args*.

        Every elementary segment covered by at least one input interval receives
        the ``operator``-reduction of the values active on it. ``_intervals`` and
        ``_mapping`` are replaced wholesale at the end.
        """
        intervals = []
        mapping = {}

        (elements, cursors, current, rest) = self._create(*args)
        (lower, upper) = self._next(-atomic.INFINITY, elements, cursors,
                                    current, rest)

        while current:
            # lower/upper are bound objects; type == 0 is taken to mean a closed bound.
            interval = atomic.Interval[atomic.TO](
                lower_value=lower.value,
                lower_closed=lower.type == 0,
                upper_value=upper.value,
                upper_closed=upper.type == 0,
            )
            # ``current`` is ordered by upper bound, not by argument order, so this
            # reduction is only correct for a commutative and associative operator.
            value = reduce(
                self._operator,
                (value for (_, _, _, value) in current)  # type: ignore
            )
            intervals.append(interval)
            mapping[interval] = value
            (lower, upper) = self._next(upper, elements, cursors, current,
                                        rest)

        self._intervals = SortedSet(intervals)
        self._mapping = mapping

    def clear(self) -> None:
        """Remove all items from self (same as del self[:])."""
        self._intervals = SortedSet()
        self._mapping = {}

    def _start(self, interval):
        """
        Return the index where :meth:`_remove` starts deleting stored intervals.

        The stored interval at ``bisect_left(interval)`` is inspected. If it
        overlaps *interval* from one side, or contains it, it is removed and its
        part(s) lying outside *interval* are re-inserted through :meth:`_insert`.
        The returned index is advanced by one for each re-inserted piece.
        """
        start = self._intervals.bisect_left(interval)
        if start < len(self._intervals):
            start_interval = self._intervals[start]
            start_value = self._mapping[start_interval]
            if self._intervals[start].overlaps(interval, strict=False):
                # Keep the left remainder [stored.lower, interval.lower).
                del self._intervals[start]
                del self._mapping[start_interval]
                if self._insert(
                    (
                        start_interval.lower_value,
                        interval.lower_value,
                        start_interval.lower_closed,
                        not interval.lower_closed,
                    ),
                        start_value,
                ):
                    start += 1
            elif interval.overlaps(self._intervals[start], strict=False):
                # Keep the right remainder beyond interval.upper.
                del self._intervals[start]
                del self._mapping[start_interval]
                if self._insert(
                    (
                        interval.upper_value,
                        start_interval.upper_value,
                        not interval.upper_closed,
                        start_interval.upper_closed,
                    ),
                        start_value,
                ):
                    start += 1
            elif interval.during(self._intervals[start], strict=False):
                # *interval* sits inside the stored one: keep both the left and the right remainders.
                del self._intervals[start]
                del self._mapping[start_interval]
                if self._insert(
                    (
                        start_interval.lower_value,
                        interval.lower_value,
                        start_interval.lower_closed,
                        not interval.lower_closed,
                    ),
                        start_value,
                ):
                    start += 1
                if self._insert(
                    (
                        interval.upper_value,
                        start_interval.upper_value,
                        not interval.upper_closed,
                        start_interval.upper_closed,
                    ),
                        start_value,
                ):
                    start += 1
        return start

    def _stop(self, interval):
        """
        Return the index used by :meth:`_remove` as the exclusive end of the
        stored intervals to delete.

        The stored interval just before ``bisect_right(interval)`` is inspected:
        if it lies within *interval* it is deleted here; if *interval* overlaps
        it, it is deleted and its remainder beyond *interval* is re-inserted
        through :meth:`_insert`.
        """
        stop = self._intervals.bisect_right(interval)
        if 0 < stop <= len(self._intervals):
            stop -= 1
            stop_interval = self._intervals[stop]
            stop_value = self._mapping[stop_interval]
            if self._intervals[stop].during(interval, strict=False):
                del self._intervals[stop]
                del self._mapping[stop_interval]
            elif interval.overlaps(self._intervals[stop], strict=False):
                del self._intervals[stop]
                del self._mapping[stop_interval]
                self._insert(
                    (
                        interval.upper_value,
                        stop_interval.upper_value,
                        not interval.upper_closed,
                        stop_interval.upper_closed,
                    ),
                    stop_value,
                )

        return stop

    def _insert(self, key, value):
        """
        Build an interval with ``atomic.Atomic.from_tuple(key)`` and, if it is
        truthy, store it with *value* and return True; otherwise return False.

        At the call sites *key* is ``(lower value, upper value, lower closed,
        upper closed)``.
        """
        interval = atomic.Atomic.from_tuple(key)
        if interval:
            self._intervals.add(interval)
            self._mapping[interval] = value
            return True
        return False

    def _bisect_left(self, search) -> int:
        """Return ``bisect_left`` of *search* in the sorted interval set."""
        return self._intervals.bisect_left(search)

    def _update_intervals(self, intervals) -> None:
        """Replace the interval set with a SortedSet of *intervals* (``_mapping`` is left untouched)."""
        self._intervals = SortedSet(intervals)
コード例 #26
0
class BindingSites:
    """
    BindingSites allows adding binding sites, removing binding sites, querying
    binding sites given interval ranges of interest, measuring "correlations"
    between pairs of sets of binding sites, printing BED files, and other
    related functions.

    BindingSites can be used in two modes:
        1. when overlap_mode is on, sites are allowed to overlap. This allows
            for queries into "depths" of binding (i.e. how many proteins are
            binding together in a given interval on the RNA).
        2. when overlap_mode is off, sites that overlap are always merged
            together, with no "depth" information maintained. The identity and
            description of binding sites to be merged are also merged and
            associated with the merged binding site.

    Note that this project concerns itself with RNAs and proteins, but this
    class could very well be used in other similar scenarios (like proteins
    interacting with DNA).

    The underlying implementation uses an ordered set to keep track of the
    ranges, while adding a range to the set will check for overlaps and deal
    with them.

    This may have been unnecessary, as there are O(nlogn) algorithms that
    produce merged intervals given a list of ranges that can be found online.
    For example:
    https://codereview.stackexchange.com/
                                questions/69242/merging-overlapping-intervals

    However, I have implemented this now, so I will keep it. It also allows
    simple dynamic additions and deletions in O(logn) each.

    Major modifications note July 21st 2019:
    Most of the functions defined here have an implicit prerequisite that the
    binding site intervals are non-overlapping. As a result,
    the BindingSites.add() function makes sure that overlapping intervals
    are joined together to form a (possibly) larger interval that includes both
    of the intervals.

    However, some of the input data might come from experimental sources
    where overlaps are very important (i.e. they represent confidence of
    binding regions because more sequences were identified from that particular
    region). As a result, a second mode keeps track of the raw binding sites,
    so that more functionality can be supported, such as finding out the
    "depth of support" for RBPs binding to a specific nucleotide, as well as
    filters for collapsing the overlapping regions based on a criterion
    (e.g. only those with support depth 5 or above, etc.)

    Therefore, as of now, there are two main supported ways of using
    BindingSites:
        1. You add intervals and let BindingSites dynamically take care of
            overlapping intervals for you. This is the default behaviour
            (overlap_mode=False).
        2. You initialize BindingSites with overlap_mode=True.
            You then add intervals that are highly overlapping and, once
            complete, call the overlap_collapse() function to collapse all the
            intervals as per your specifications. With in_place=True this sets
            overlap_mode to off. Following this, the BindingSites holds the
            non-overlapping counterparts, which enables the functions that are
            disabled while overlap_mode is switched on.

    If overlap_collapse() is called too soon, everything has to be loaded again
    fresh.
    """

    def __init__(self, list_of_sites=None, overlap_mode=False):
        """Create a BindingSites instance.

        :param list_of_sites: optional iterable of sites in the form
                              (start, end) or (start, end, metadata); each is
                              passed to add() (Default value = None)
        :param overlap_mode: if True, overlapping sites are stored as-is
                             instead of being merged (Default value = False)
        """
        if list_of_sites is None:
            list_of_sites = []
        self.overlap_mode = overlap_mode
        # Sites are stored as (start, end, metadata) tuples in a sorted set.
        self.sorted_sites = SortedSet()
        for site in list_of_sites:
            self.add(site)

    def __repr__(self, display_meta=False):
        """Representation of BindingSites objects.

        Show all three elements (start, end, metadata) by setting
        display_meta=True, or just show the first two tuple elements
        for succinctness.
        """
        overlap_add = "OverlapOn" if self.overlap_mode else ""
        if display_meta:
            return (
                self.sorted_sites.__repr__()
                    .replace("SortedSet", "BindingSites" + overlap_add)
            )

        return (
            SortedSet(map(firstTwoItems, self.sorted_sites))
            .__repr__().replace("SortedSet", "BindingSites" + overlap_add)
        )

    def __str__(self):
        return self.__repr__(display_meta=True)

    def __len__(self):
        return self.sorted_sites.__len__()

    def __iter__(self):
        return self.sorted_sites.__iter__()

    def __getitem__(self, item):
        """Index into the sorted sites; slices return a new BindingSites.

        A sliced result keeps this instance's overlap_mode, so slicing an
        overlap-mode instance does not merge its (overlapping) sites.
        """
        if isinstance(item, slice):
            return BindingSites(self.sorted_sites[item],
                                overlap_mode=self.overlap_mode)
        return self.sorted_sites[item]

    @staticmethod
    def is_overlap_ranges(interval_1, interval_2):
        """True iff the closed ranges interval_1 and interval_2 overlap.

        Ranges sharing only an end point (e.g. (1, 5) and (5, 9)) count as
        overlapping.

        :param interval_1: an interval in the form (start, end) or
                           (start, end, metadata)
        :param interval_2: an interval in the form (start, end) or
                           (start, end, metadata)

        """
        start_1, end_1, *_ = interval_1
        start_2, end_2, *_ = interval_2
        return start_1 <= end_2 and start_2 <= end_1

    @staticmethod
    def _merge_meta(annotation_list, user_merge_func=None):
        """
        Internal function for merging the annotations of multiple
        binding site ranges. Some of the annotations passed in may themselves
        be tuples of annotations; these are flattened first, and None
        annotations are ignored.

        :param annotation_list: A list of annotations to be merged
        :param user_merge_func: if specified, this function is used to merge the
                                flattened list of annotations instead
                                (Default value = None)
        :returns: None if no annotations remain; the single annotation if
                  exactly one remains; the result of user_merge_func if it is
                  given; otherwise the distinct annotations as a tuple in
                  order of first appearance (or the single value if they are
                  all equal).

        """
        flattened = []
        for element in annotation_list:
            if element is None:
                continue
            if isinstance(element, tuple):
                flattened.extend(element)
            else:
                flattened.append(element)

        if not flattened:
            return None

        if len(flattened) == 1:
            return flattened[0]

        if user_merge_func is not None:
            return user_merge_func(flattened)

        # Deduplicate while keeping first-seen order, so the result does not
        # depend on set iteration order (which varies with hash seeding).
        # Equal annotations such as ['a', 'a'] collapse to that single value.
        distinct = tuple(dict.fromkeys(flattened))
        if len(distinct) > 1:
            return distinct
        return distinct[0]

    @staticmethod
    def _collapse(interval_list, user_merge_func=None):
        """Takes a list of overlapping ranges and collapses them into one.

        The module-level OVERLAP_CONFLICT setting decides the bounds:
        "union" spans from the smallest start to the largest end, and
        "intersect" spans from the largest start to the smallest end.

        :param interval_list: A list of (start, end, metadata) intervals to be
                              merged
        :param user_merge_func: if specified, this function is used to merge the
                                metadata (otherwise, metadata is simply joined
                                into a tuple) (Default value = None)
        :raises ValueError: if OVERLAP_CONFLICT is neither "union" nor
                            "intersect"

        """
        if OVERLAP_CONFLICT == "union":
            start = min(interval_list, key=firstItem)[0]
            end = max(interval_list, key=secondItem)[1]
        elif OVERLAP_CONFLICT == "intersect":
            start = max(interval_list, key=firstItem)[0]
            end = min(interval_list, key=secondItem)[1]
        else:
            raise ValueError("Unsupported OVERLAP_CONFLICT value: "
                             + repr(OVERLAP_CONFLICT))

        return (start,
                end,
                BindingSites._merge_meta(
                    list(map(thirdItem, interval_list)),
                    user_merge_func))

    def _candidate_sites(self, start, end):
        """Return the stored sites that may overlap the range [start, end].

        Because non-overlapping stored sites are sorted, only the sites from
        the one just before `start` up to the one just after `end` need to be
        checked. The slice is a list copy, so callers may modify sorted_sites
        while iterating over the result.
        """
        start_pos = self.sorted_sites.bisect_left((start, 0))
        end_pos = self.sorted_sites.bisect_left((end, 0))
        lower = max(0, start_pos - 1)
        higher = min(end_pos + 1, len(self.sorted_sites))
        return self.sorted_sites[lower:higher]

    def add(self, new_site, user_merge_func=None):
        """Dynamic addition of a range to a sorted set of non-overlapping ranges
        while maintaining the sorted property and merging any produced overlaps.

        May not be the most efficient way of doing this, as the _collapse
        function does not take advantage of the sortedness of the ranges.

        :param new_site: the new site to be added, in the form
                         (start, end, metadata) or (start, end).

        :param user_merge_func: If specified, this function is used to merge
                                the metadata of the binding sites in case of
                                overlaps (and a need to merge the sites, due to
                                overlap_mode being off).
                                This function should expect a list of
                                annotations and return one.
                                Otherwise, the default behaviour collates the
                                annotations into a tuple
                                (Default value = None)
        :raises ValueError: if the site does not have 2 or 3 values, start or
                            end is not an int, or start > end

        """
        if len(new_site) == 2:
            start, end = new_site
            new_site = (start, end, None)
        elif len(new_site) != 3:
            raise ValueError("Please keep three values in the tuple: " +
                             "(start, end, annotation)")
        else:
            # Sites are stored in a set, so they must be hashable tuples.
            new_site = tuple(new_site)

        start, end, _ = new_site
        if not isinstance(start, int) or not isinstance(end, int):
            raise ValueError("Please make sure start and end are integers")

        if start > end:
            raise ValueError(
                "Please make sure the interval end point is greater than the"
                " start point!")

        if self.overlap_mode:
            self.sorted_sites.add(new_site)
            return

        # This part could be O(n) theoretically but experimentally,
        # the candidate window is always small (fewer than 5 sites) for this
        # data.
        to_merge = [new_site]
        for site in self._candidate_sites(start, end):
            if BindingSites.is_overlap_ranges(site, new_site):
                self.sorted_sites.remove(site)
                to_merge.append(site)

        self.sorted_sites.add(BindingSites._collapse(to_merge, user_merge_func))

    def remove(self, site):
        """
        Removes a binding site from an instance of BindingSites.

        :param site: a site to be removed from the BindingSites object. This
                     should exactly match the site that is in the object
                     (if metadata was specified, it should be specified too).
                     This is easier done by obtaining the site using one of the
                     query methods.

        """
        self.sorted_sites.remove(site)

    def dist(self, sites, bp_threshold=30):
        """Checks for correlation between binding sites of two BindingSites.
        Returns a value from 0 to 1.

        WARNING: a.dist(b) and b.dist(a) can give VERY different answers!
        This is because this function checks the "distances" of each of the
        binding sites of one set of BindingSites with all of the other set
        to count the minimal distance and give a scoring based on that.
        Since one of the set of BindingSites could be ubiquitous, the scores
        may vary greatly.

        If this instance holds no sites, a single query site scores 0.

        :param sites: a single (start, end[, metadata]) tuple, or a
                      BindingSites object to compare with.
        :param bp_threshold: the threshold distance value beyond which a site
                             is considered too far from another. The score of
                             proximity is given based on this cutoff value.
                             (Default value = 30)
        :raises ValueError: in overlap_mode, or for an unsupported `sites` type

        """
        if self.overlap_mode:
            raise ValueError("dist() is not supported for BindingSites with"
                             " overlap_mode set to True")

        # Todo: break this function into two: one that accepts one site, and
        # another that accepts a BindingSite
        if isinstance(sites, tuple):  # only one tuple input
            start = sites[0]
            end = sites[1]
            if not self.sorted_sites:
                # Nothing to be close to: as far away as possible.
                return 0

            pos = self.sorted_sites.bisect_left((start, 0))

            # tuple at the beginning
            if pos == 0:
                dist_end = max(0, self.sorted_sites[pos][0] - end)
                return max(0, 1 - dist_end / bp_threshold)

            # tuple at the end
            if pos == len(self.sorted_sites):
                dist_start = max(0, start - self.sorted_sites[pos - 1][1])
                return max(0, 1 - dist_start / bp_threshold)

            # tuple in the middle: score by the closer neighbour
            dist_start = max(0, start - self.sorted_sites[pos - 1][1])
            dist_end = max(0, self.sorted_sites[pos][0] - end)
            return max(0, 1 - min(dist_start, dist_end) / bp_threshold)

        if isinstance(sites, BindingSites):  # a set of tuples given
            total = sum(self.dist(site, bp_threshold) for site in sites)
            return total / len(sites)

        # non-supported input type
        print("sites is of type", type(sites))
        raise ValueError("Unsupported type for sites, should be a tuple" +
                            " or BindingSites")

    def is_overlap(self, query_site):
        """This checks if an input tuple range overlaps one of those present in
        the set of binding sites stored in self and returns a bool to indicate
        the result.

        :param query_site: a site in the form (start, end, metadata) or just
                            (start, end)
        :raises ValueError: in overlap_mode

        """
        if self.overlap_mode:
            raise ValueError("isOverlap() is not supported for BindingSites"
                             " with overlap_mode set to True")

        start, end, *_ = query_site
        return any(BindingSites.is_overlap_ranges(site, query_site)
                   for site in self._candidate_sites(start, end))

    def nearest_site(self, query_site):
        """This returns the closest range to the input tuple range
        present in the set of binding sites stored in self, together with its
        distance (0 when they overlap).

        :param query_site: a site in the form (start, end, metadata) or just
                            (start, end)
        :returns: a (site, distance) tuple
        :raises ValueError: in overlap_mode
        :raises IndexError: if no sites are stored

        """
        if self.overlap_mode:
            raise ValueError("nearestSite() is not supported for BindingSites"
                             " with overlap_mode set to True")

        start, end, *_ = query_site
        pos = self.sorted_sites.bisect_left((start, 0))

        # tuple at the beginning
        if pos == 0:
            dist_end = self.sorted_sites[pos][0] - end
            return self.sorted_sites[pos], max(0, dist_end)

        # tuple at the end
        if pos == len(self.sorted_sites):
            dist_start = start - self.sorted_sites[pos - 1][1]
            return self.sorted_sites[pos - 1], max(0, dist_start)

        # tuple in the middle
        dist_start = start - self.sorted_sites[pos - 1][1]
        dist_end = self.sorted_sites[pos][0] - end

        if dist_start > dist_end:
            return self.sorted_sites[pos], max(0, dist_end)

        return self.sorted_sites[pos - 1], max(0, dist_start)

    @staticmethod
    def distance(site_1, site_2):
        """
        Returns the distance between two interval ranges. If they overlap, the
        distance is zero.

        :param site_1: a site in the form (start, end, metadata) or just
                            (start, end)
        :param site_2: a site in the form (start, end, metadata) or just
                            (start, end)

        """
        if BindingSites.is_overlap_ranges(site_1, site_2):
            return 0

        start_1, end_1, *_ = site_1
        start_2, end_2, *_ = site_2
        return min(abs(start_1 - end_2), abs(start_2 - end_1))

    def print(self):
        """
        prints the BindingSites instance
        """
        print(self)

    def len(self):
        """
        returns the number of sites stored in the BindingSites instance
        """
        return len(self)

    def filter_overlap(self, query_range, bp_threshold=0):
        """
        Given a query range, returns a new BindingSites instance with only sites
        that overlap the given range (widened by bp_threshold on each side).

        :param query_range: a range/site in the form (start, end) or
                            (start, end, metadata)
        :param bp_threshold: number of bases by which the query range is
                             extended on both sides (Default value = 0)
        :raises ValueError: in overlap_mode

        """
        if self.overlap_mode:
            raise ValueError("filterOverlap() is not supported for BindingSites"
                             " with overlap_mode set to True")

        start, end, *_ = query_range
        start -= bp_threshold
        end += bp_threshold
        padded_range = (start, end)

        output_binding_sites = BindingSites()
        for site in self._candidate_sites(start, end):
            if BindingSites.is_overlap_ranges(site, padded_range):
                output_binding_sites.add(site)
        return output_binding_sites

    def print_bed(self, name="Generic Binding Site", chr_n=1, displacement=0,
                 end_inclusion=False, add_annotation=False, include_score=False,
                 score_max=1000, score_base=1000, include_color=False,
                 conditional_color_func=None, is_bar=False,
                 is_additional_columns=False,
                 annotation_to_additional_columns=None):
        """Render the stored sites as BED-formatted lines (one per site).

        :param name: value of the name column (Default value = "Generic Binding Site")
        :param chr_n: chromosome; "chr" is prefixed unless already present
                      (Default value = 1)
        :param displacement: offset added to every start and end (Default value = 0)
        :param end_inclusion: if True, 1 is added to each end (Default value = False)
        :param add_annotation: unused (Default value = False)
        :param include_score: if True, a score is parsed from the site's single
                              annotation and scaled by score_max / score_base
                              (Default value = False)
        :param score_max:  (Default value = 1000)
        :param score_base:  (Default value = 1000)
        :param include_color: if True, strand, thick start/end and an RGB color
                              column are appended (Default value = False)
        :param conditional_color_func: maps a site tuple to (r, g, b); black is
                                       used if None (Default value = None)
        :param is_bar: if True, appends bar columns using the score
                       (Default value = False)
        :param is_additional_columns: if True, appends the columns returned by
                                      annotation_to_additional_columns; spaces
                                      become "_" and empty values become "."
                                      (Default value = False)
        :param annotation_to_additional_columns: maps an annotation to a list of
                                                 strings (Default value = None)
        :raises ValueError: if both include_color and is_bar are set, or if
                            is_bar is set without include_score

        """
        # Todo: investigate the reason for add_annotation being an unused
        # argument
        del add_annotation

        # Todo: Possibly remove the functionality for is_bar, it seems
        # misplaced!
        if include_color and is_bar:
            raise ValueError("Cant be both color and bar!")
        if is_bar and not include_score:
            raise ValueError("What height for bar?")

        if not isinstance(chr_n, str):
            chr_n = "chr" + str(chr_n)
        elif chr_n[:3] != "chr":
            chr_n = "chr" + chr_n

        lines = []
        for _tuple in self.sorted_sites:
            start, end, annotation = _tuple
            start = displacement + start
            end = displacement + end + (1 if end_inclusion else 0)

            to_join = [chr_n, start, end, name]

            if include_score:
                assert len(annotation) == 1
                # Keep only numeric characters of the annotation text.
                score = float("".join(filter(
                    lambda k: k.isdigit() or k == "." or k == "-",
                    annotation[0])))
                score = int(score / score_base * score_max)
                to_join.append(score)
            elif include_color:
                # A color track still needs a score column placeholder.
                score = 1000
                to_join.append(score)

            if include_color:
                strand = "+"  # default

                if conditional_color_func is None:
                    color = "0,0,0"  # black
                else:
                    red, green, blue = conditional_color_func(_tuple)
                    color = ','.join(map(str, [red, green, blue]))

                to_join += [strand, start, end, color]

            if is_bar:
                strand = "+"  # default
                number_of_bars = 1
                to_join += [strand, name, number_of_bars, score]

            if is_additional_columns:
                # "." is the BED placeholder for an empty field.
                to_join += [
                    s.replace(" ", "_") if s else "."
                    for s in annotation_to_additional_columns(annotation)
                ]
            lines.append("\t".join(map(str, to_join)) + "\n")

        return "".join(lines)

    def return_depth(self, length=-1):
        """
        Returns the density array of binding sites stored.

        Note that binding sites can bind on the 0th nucleotide, and cannot bind
        on the length'th nucleotide

        :param length: a parameter guaranteed to be bigger than the end point
        of any site stored in the BindingSites instance. Specify -1 if unknown.
        (Default value = -1)
        :raises ValueError: if length is -1 and no sites are stored

        """
        if length == -1:
            if len(self) == 0:
                raise ValueError("If the BindingSites object is empty, please"
                                 " do not call return_depth without "
                                 "specifying the length parameter.")
            length = max(map(secondItem, self)) + 1

        # Stores 'depth' of support for each nucleotide in the
        # molecule in terms of its chances of being a binding site.
        binding_depth = [0] * length

        for site in self:
            start = site[0]
            end = site[1]

            for nucleotide in range(start, end + 1):  # inclusive
                binding_depth[nucleotide] += 1

        return binding_depth

    def overlap_collapse(self, mode, number, in_place=False,
                                                        annotation_merger=None):
        """Collapses the overlapping ranges to non-overlapping ones, based on
        preset conditions.

        This function will always look at the 'depths' of how much coverage
        support each nucleotide position has and chooses a cutoff point - e.g.
        all nucleotides above depth level of 5 are kept as binding sites and the
        rest are discarded. The cutoff point can be chosen through multiple
        means.

        The modes supported right now are 'baseCoverNumber', 'TopDepthRatio',
        'TopDepthNumber', 'MinimumDepthNumber' ('TopSitesNumber' and
        'TopSitesRatio' are recognised but not implemented yet).

            'baseCoverNumber': Choose cut off based on number of bases that
                should be covered by the selected sites. The most stringent
                cutoff that achieves this criterion is selected, unless not
                possible*.

            'TopDepthRatio': Choose cutoff based on the fraction of highest
                depth coverage that should be supported as binding sites. For
                example, if the deepest coverage provided is 10 and number=0.4,
                then depth coverage of 10,9,8,7,and 6 is counted as binding
                sites and the rest are disregarded.

            'TopDepthNumber': The number of layers of depth from the highest
                depth support that should be selected is input, and the cutoff
                is accordingly selected.

            'MinimumDepthNumber': The cut-off is selected based on the minimum
                depth support each binding site should have.

            'TopSitesNumber': The cut-off is selected such that the top selected
                number of binding sites remains supported. For example, the top
                100 sites may be preserved (from a set of, say, 1000 overlapping
                sites)

            'TopSitesRatio': The cut-off is selected much like above, but the
                ratio of top sites that should be selected is specified instead.
                For example, in the above example, 0.1 could be specified
                instead.

        Annotations: every original site that overlaps a collapsed site is
        re-added with add(), merging its annotation via annotation_merger.
        Depending on OVERLAP_CONFLICT ("union"), this can also widen the
        collapsed site to the original site's extent.

        If in_place is set to True, this BindingSites instance changes and
        collapses, otherwise a new BindingSites instance is generated and
        returned.

        *In general, no nucleotide with support<1 is kept.

        :param mode: a string representing the mode to be used
        :param number: an appropriate numeric argument based on the mode
                       specified
        :param in_place: if True, modifies the instance on which this function
                         is called (and sets overlap_mode to False), and an
                         empty BindingSites is returned.
                         Otherwise, creates a new BindingSites instance and
                         returns it (Default value = False)
        :param annotation_merger: If specified, this argument (a function) is
                                  used to merge the annotations of the
                                  overlapping binding sites. Otherwise, the
                                  annotations are merged into a tuple
                                  (Default value = None)
        :raises ValueError: for unknown/unimplemented modes or an out-of-range
                            'TopDepthRatio'

        """
        if not self.overlap_mode:
            print("WARNING: overlap_collapse() called although overlap_mode is"
                  " set to off!")

        depth_array = self.return_depth()

        max_depth = max(depth_array)

        if mode == 'baseCoverNumber':
            # Raise the cutoff until at most `number` bases lie above it.
            # Bounding by max_depth (where 0 bases lie above) guarantees
            # termination even for a negative `number`.
            depth_cutoff = -1
            while (depth_cutoff < max_depth and
                   sum(1 for k in depth_array if k > depth_cutoff) > number):
                depth_cutoff += 1

        elif mode == 'TopDepthRatio':
            if not 0 <= number <= 1:
                raise ValueError("Ratio should be between 0 and 1")

            depth_cutoff = max_depth * (1 - number)

        elif mode == 'TopDepthNumber':
            depth_cutoff = max_depth - number

        elif mode == 'MinimumDepthNumber':
            depth_cutoff = number - 1

        elif mode == 'TopSitesNumber':
            raise ValueError("Unimplemented function")
        elif mode == 'TopSitesRatio':
            raise ValueError("Unimplemented function")
        else:
            raise ValueError("The mode selected, '" + mode + "' is not"
                             " supported!")

        sites = self.sorted_sites
        depth_cutoff = max(0, depth_cutoff)
        if in_place:
            self.overlap_mode = False
            self.sorted_sites = SortedSet()
            binding_site_to_add_to = self
        else:
            binding_site_to_add_to = BindingSites()

        # Each maximal run of nucleotides deeper than the cutoff becomes one
        # (start, end) site with an inclusive end.
        in_range = False
        start_range = 0
        nucleotide = 0
        for nucleotide, depth in enumerate(depth_array):
            if depth > depth_cutoff:
                if not in_range:
                    start_range = nucleotide
                    in_range = True
            elif in_range:
                in_range = False
                binding_site_to_add_to.add((start_range, nucleotide - 1))

        if in_range:
            binding_site_to_add_to.add((start_range, nucleotide))

        # Add annotations. When nothing survived the cutoff there is nothing
        # to annotate (and nearest_site() would fail on an empty set).
        if len(binding_site_to_add_to) > 0:
            for site in sites:
                _, distance = binding_site_to_add_to.nearest_site(site)
                if distance == 0:
                    binding_site_to_add_to.add(tuple(site), annotation_merger)

        if not in_place:
            return binding_site_to_add_to

        return BindingSites()

    def base_cover(self):
        """
        Returns the number of bases of the RNA molecule that have at least one
        protein bound to it

        (or simply, the length of space covered by the intervals stored in
        BindingSites)
        """
        return sum(1 for depth in self.return_depth() if depth > 0)

    def print_wig(self, chr_no=1, displacement=0, include_name=False,
                  include_description=False, name="",
                  description="", include_header=True, length=-1):
        """Returns a wig file (as a string) depicting density of binding sites
        by the RBP.

        Optional parameter length allows for plotting 0 beyond the rightmost
        binding site if needed.

        :param chr_no: chromosome number, prefixed with "chr" (Default value = 1)
        :param displacement: offset of the first base (Default value = 0)
        :param include_name: add a name="..." header field (Default value = False)
        :param include_description: add a description="..." header field
                                    (Default value = False)
        :param name:  (Default value = "")
        :param description:  (Default value = "")
        :param include_header: emit the "track" header line (Default value = True)
        :param length: passed to return_depth (Default value = -1)

        """
        output_str = ""
        if include_header:
            output_str += "track type=wiggle_0 "
            if include_name:
                output_str += 'name="' + name + '" '
            if include_description:
                output_str += 'description="' + description + '" '
            output_str += "visibility=full"
            output_str += '\n'

        # Note the +1 below. I suspect this is necessary as wig files are
        # 1-based...
        output_str += ("fixedStep chrom=chr" + str(chr_no) + " start="
                      + str(displacement + 1) + " step=1")

        output_str += "\n"

        # length long array, 0-indexed
        depth_array = self.return_depth(length=length)
        output_str += "\n".join(map(str, depth_array))

        return output_str
コード例 #27
0
一款纯python写的对列表、字典、集合排序的模块
下面显示的所有操作都比线性时间快

>>> from sortedcontainers import SortedList
>>> sl = SortedList(['e', 'a', 'c', 'd', 'b'])
>>> sl
SortedList(['a', 'b', 'c', 'd', 'e'])
>>> sl *= 10_000_000
>>> sl.count('c')
10000000
>>> sl[-3:]
['e', 'e', 'e']

>>> from sortedcontainers import SortedDict
>>> sd = SortedDict({'c': 3, 'a': 1, 'b': 2})
>>> sd
SortedDict({'a': 1, 'b': 2, 'c': 3})
>>> sd.popitem(index=-1)
('c', 3)

>>> from sortedcontainers import SortedSet
>>> ss = SortedSet('abracadabra')
>>> ss
SortedSet(['a', 'b', 'c', 'd', 'r'])
>>> ss.bisect_left('c')
2
コード例 #28
0
class Intervals:
  """Collection of 1-D ranges kept in a SortedSet (``self.xl``).

  ``add`` merges a new range with stored ranges it overlaps; when ``merge`` is
  truthy it additionally merges with the preceding range if that range
  contains ``elem.low - 1`` (touches the new range from the left). Elements
  are ``Range1D`` objects, or ``Range1DWithData`` when ``use_range_data`` is
  set.

  NOTE(review): the SortedSet order comes from Range1D's comparison methods,
  defined elsewhere; ``add``/``query`` presumably assume ordering by ``low``
  first — confirm.
  """

  def __init__(self, xl=[], regions=None, use_range_data=False, merge=1, is_int=None):
    """Build the collection from ``regions`` or from ``xl``.

    :param xl: range spec(s) passed to ``add`` one by one. If its elements are
      not lists/tuples, the whole ``xl`` is treated as a single spec. The
      ``[]`` default is only read/rebound here, never mutated.
    :param regions: if not None, overrides ``xl``; each object's
      ``getRegion()`` result is added.
    :param use_range_data: store ``Range1DWithData`` instead of ``Range1D``.
    :param merge: when truthy, ``add`` also merges a range touching its
      predecessor.
    :param is_int: if not None, forced as ``is_int`` on every range ``add``
      builds.
    """
    self.xl = SortedSet()
    self.use_range_data = use_range_data
    self.merge = merge
    self.is_int = is_int
    if regions is not None:
      xl = [r.getRegion() for r in regions]
    else:
      xl = to_list(xl)
      # A flat spec (elements are not lists/tuples) is one range, so wrap it
      # for the loop below.
      if len(xl) > 0 and not isinstance(xl[0], (list, tuple)):
        xl = [xl]

    for x in xl:
      self.add(x)
    glog.debug('Intervals >> %s', self.xl)

  def add(self, *args, **kwargs):
    """Insert a range built from ``args``/``kwargs`` and merge overlaps.

    The range is ``Range1DWithData(*args, **kwargs)`` when ``use_range_data``
    is set, else ``Range1D(*args, **kwargs)``; ``self.is_int`` (when not None)
    overrides any ``is_int`` keyword. Empty ranges are ignored.

    :returns: ``self``, so calls can be chained.
    """
    kwargs = dict(kwargs)
    if self.is_int is not None:
      kwargs['is_int'] = self.is_int

    if self.use_range_data:
      elem = Range1DWithData(*args, **kwargs)
    else:
      elem = Range1D(*args, **kwargs)

    if elem.empty:
      return self

    pos = self.xl.bisect_left(elem)
    cur = None
    next_pos = pos + 1
    if pos > 0:
      prev = self.xl[pos - 1]
      if self.merge and prev.contains(elem.low - 1):
        # Replace the predecessor by its union with the new range.
        self.xl.pop(pos - 1)
        ne = prev.union(elem)
        self.xl.add(ne)
        cur = ne
        next_pos = pos
      else:
        assert elem.low >= prev.high

    if cur is None:
      cur = elem
      self.xl.add(elem)

    # Take `cur` out of the set while absorbing overlapping successors. The
    # pop must not live inside the assert: `python -O` strips asserts, and
    # the removal would silently disappear with it.
    popped = self.xl.pop(next_pos - 1)
    assert cur == popped
    next_pos -= 1
    while next_pos < len(self.xl):
      nxt = self.xl[next_pos]
      if cur.high < nxt.low:
        break
      self.xl.pop(next_pos)
      cur = cur.union(nxt)
      cur.high = max(cur.high, nxt.high)
    self.xl.add(cur)

    return self

  def filter_dataset(self, dataset):
    """Return ``dataset.make_new('filter', y=..., x=...)`` with only the first
    ``dataset.n`` points whose x value passes ``should_keep``."""
    nx = []
    ny = []
    for i in range(dataset.n):
      if self.should_keep(dataset.x[i]):
        nx.append(dataset.x[i])
        ny.append(dataset.y[i])
    return dataset.make_new('filter', y=ny, x=nx)

  def should_keep(self, v):
    """Return True if ``xlow <= v < xhigh`` for some stored range.

    NOTE(review): this half-open test differs from ``query``, which treats
    ``high`` as inclusive, and it unpacks each element as a ``(low, high)``
    pair, which presumably relies on Range1D being iterable — confirm.
    """
    if self.xl is None:  # never true after __init__; kept as a guard
      return True
    return any(xlow <= v < xhigh for xlow, xhigh in self.xl)

  def split_data(self, dataset):
    """Return ``dataset[max(0, low):high]`` for every stored range, in order."""
    return [dataset[max(0, v.low):v.high] for v in self.xl]

  @staticmethod
  def FromIndices(vals, merge_dist=1):
    """Group indices into integer ranges.

    A value within ``merge_dist`` of the current range's ``high`` extends it;
    otherwise a new ``Range1D(x, x, is_int=1)`` is started. Ranges are put
    straight into ``res.xl`` (no ``add`` merging). ``vals`` is presumably
    sorted ascending, since ``high`` is overwritten with each value.
    """
    last = None
    res = Intervals()
    for x in vals:
      if last is None or last.high + merge_dist < x:
        if last is not None:
          res.xl.add(last)
        last = Range1D(x, x, is_int=1)
      else:
        # Formerly a stray debug print() to stdout.
        glog.debug('merging %s into range starting at %s', x, last.low)
      last.high = x
    if last is not None:
      res.xl.add(last)

    return res

  def complement(self, superset):
    """Return the gaps between stored ranges within ``superset``.

    Gaps are built as integer ranges: from ``superset.low`` up to the first
    range's ``low - 1``, between consecutive ranges, and up to
    ``superset.high - 1``. NOTE(review): gaps are inserted via ``res.xl.add``
    without an emptiness check, and stored ranges are not clipped to
    ``superset``.
    """
    superset = Range1D(superset, is_int=1)
    cur = Range1D(superset.low, 0, is_int=1)
    res = Intervals()

    # The trailing sentinel only contributes its `low` (= superset.high).
    for e in list(self.xl) + [Range1D(superset.high, 0, is_int=1)]:
      cur.high = e.low - 1
      res.xl.add(cur.clone())
      cur.low = e.high + 1
    return res

  def shorten(self, val):
    """Return a new Intervals with every range shrunk by ``val`` on both ends."""
    res = Intervals()
    for x in self.xl:
      res.add(Range1D(x.low + val, x.high - val, is_int=1))
    return res

  def expand(self, val):
    """Return a new Intervals with every range grown by ``val`` on both ends."""
    res = Intervals()
    for x in self.xl:
      res.add(Range1D(x.low - val, x.high + val, is_int=1))
    return res

  def filter(self, func):
    """Return a new Intervals holding the ranges for which ``func`` is truthy."""
    res = Intervals()
    for x in self.xl:
      if func(x):
        res.add(x)
    return res

  def shift(self, p):
    """Return a new Intervals with every range translated by ``p``.

    NOTE(review): unlike shorten/expand this does not pass ``is_int=1``.
    """
    res = Intervals()
    for x in self.xl:
      res.add(Range1D(x.low + p, x.high + p))
    return res

  def query(self, q, closest=0):
    """Return the stored range containing ``q``, or None.

    Takes the last element ordered before ``Range1D(q, math.inf)``
    (presumably the range with the largest ``low <= q``) and returns it when
    ``q <= high``, or unconditionally when ``closest`` is truthy. Returns None
    when no element precedes the probe.
    """
    pos = self.xl.bisect_left(Range1D(q, math.inf))
    if pos == 0:
      return None
    if closest or q <= self.xl[pos - 1].high:
      return self.xl[pos - 1]
    return None

  def query_data_do(self, q, action, fail_if_not=0):
    """Call ``action(range.data, q - range.low)`` on the range containing ``q``.

    Returns None if no range contains ``q``; with ``fail_if_not`` truthy an
    AssertionError (message ``hex(q)``) is raised instead.
    """
    obj = self.query(q)

    assert obj is not None or not fail_if_not, hex(q)
    if obj is None:
      return None
    return action(obj.data, q - obj.low)

  def query_data_raw(self, q, **kwargs):
    """Return ``.data`` of the range found by ``query(q, **kwargs)``, or None."""
    found = self.query(q, **kwargs)
    if found is None:
      return None
    return found.data

  def query_data(self, qr):
    """Return ``get_data(qr)`` of the range containing ``qr.low``.

    When no range contains it, returns ``qr.length() + 1`` zero bytes.
    """
    found_range = self.query(qr.low)
    if found_range is None:
      return bytes([0] * (qr.length() + 1))
    return found_range.get_data(qr)

  def get_ordered_ranges(self):
    """Return the stored ranges as a list, in SortedSet order."""
    return list(self.xl)

  def intersection(self, other):
    """Return the primitive ranges of both sets whose ``low`` both sets contain."""
    prim_ranges = get_primitive_ranges(self.get_ordered_ranges(), other.get_ordered_ranges())
    res = Intervals()
    for e in prim_ranges:
      if self.query(e.low) and other.query(e.low):
        res.add(e)
    return res

  def __str__(self):
    return 'Interval:\n' + ''.join(str(e) + '\n' for e in self.xl)

  def group_by(self, tb, res=None):
    """Group ``data`` from ``(pos, data)`` pairs by the range containing ``pos``.

    Positions outside every range are grouped under the key None.

    :param res: optional mapping to start from. It is shallow-copied, so the
      mapping object itself is not modified, but list values already in it
      are shared and get appended to. Defaults to a fresh
      ``defaultdict(list)``.
    """
    grouped = defaultdict(list) if res is None else copy(res)
    for pos, data in tb:
      grouped[self.query(pos)].append(data)
    return grouped
コード例 #29
0
def test5():
    """Demonstrate sortedcontainers.SortedSet and the built-in ``set``.

    SortedSet docs: http://www.grantjenks.com/docs/sortedcontainers/sortedset.html
    Everything is printed to stdout; nothing is returned.
    """
    from sortedcontainers import SortedSet
    # Create a SortedSet; elements are kept in ascending order.
    ss = SortedSet([3, 1, 2, 5, 4])
    print(ss)  # SortedSet([1, 2, 3, 4, 5])
    from operator import neg
    # The second positional argument is the key function; neg sorts descending.
    ss1 = SortedSet([3, 1, 2, 5, 4], neg)
    print(ss1)  # SortedSet([5, 4, 3, 2, 1], key=<built-in function neg>)
    # Convert a SortedSet to list / tuple / set.
    print(list(ss))  # [1, 2, 3, 4, 5]
    print(tuple(ss))  # (1, 2, 3, 4, 5)
    print(set(ss))  # {1, 2, 3, 4, 5}
    # Insert and delete elements.
    ss.discard(-1)  # discard() ignores a missing element silently
    ss.remove(1)  # remove() raises KeyError for a missing element; 1 exists -> SortedSet([2, 3, 4, 5])
    ss.discard(3)  # SortedSet([2, 4, 5])
    ss.add(-10)  # SortedSet([-10, 2, 4, 5])
    # First and last elements by index.
    print(ss[0])  # -10
    print(ss[-1])  # 5
    # Iterate in sorted order.
    for e in ss:
        print(e, end=", ")  # -10, 2, 4, 5,
    print()
    # Membership test.
    print(2 in ss)  # True
    # bisect_left() / bisect_right()
    print(ss.bisect_left(4))  # index of the first element >= 4 -> 2
    print(ss.bisect_right(4))  # index of the first element > 4 -> 3
    # Empty the set.
    ss.clear()
    print(len(ss))  # 0
    print(len(ss) == 0)  # True

    # ---- Unordered built-in set ----
    # Sets themselves are mutable, but their elements must be hashable, so a
    # list cannot be an element while a tuple can.
    A = {"hi", 2, ("we", 24)}
    B = set()  # an empty set must be written set(); {} creates an empty dict
    # Set algebra; every operator below also has an in-place form (&=, |=, -=, ^=).
    print("---------------------------------------")
    S = {1, 2, 3}
    T = {3, 4, 5}
    print(S & T)  # intersection: elements in both S and T
    print(S | T)  # union: elements in S or T
    print(S - T)  # difference: elements in S but not in T
    print(S ^ T)  # symmetric difference: elements in exactly one of S and T
    # Subset / superset relations.
    print("---------------------------------------")
    C = {1, 2}
    D = {1, 2}
    print(C <= D)  # is C a subset of D?  True
    print(C < D)  # is C a proper subset of D?  False
    print(C >= D)  # is C a superset of D (D a subset of C)?  True
    print(C > D)  # is C a proper superset of D?  False
    # Common set methods.
    print("---------------------------------------")
    S = {1, 2, 3, 5, 6}
    S.add(4)  # add x to S if it is not already present
    S.discard(1)  # remove x from S; no error if x is missing
    S.remove(2)  # remove x from S; raises KeyError if x is missing
    for e in S:  # iterate (order is not guaranteed)
        print(e, end=",")
    print()
    print(S.pop())  # remove and return an arbitrary element; KeyError if S is empty
    print(S.copy())  # shallow copy; changing the copy does not affect S
    print(len(S))  # number of elements in S
    print(5 in S)  # True if 5 is in S, otherwise False
    print(5 not in S)  # True if 5 is NOT in S, otherwise False
    S.clear()  # remove all elements from S
コード例 #30
0
"""
https://leetcode.com/discuss/general-discussion/452863/how-do-you-deal-with-no-treeset-or-treemap-in-python
Python3 可以使用 sortedcontainers 库来实现 Java 中 TreeSet, TreeMap, Collections.sort(list)
"""

from sortedcontainers import SortedList

sl = SortedList(['e', 'a', 'c', 'd', 'b'])
print(sl)  # SortedList(['a', 'b', 'c', 'd', 'e'])
sl *= 100
print(sl.count('c'))  # 100
print(sl[-3:])  # ['e', 'e', 'e']

from sortedcontainers import SortedDict

sd = SortedDict({'c': 3, 'a': 1, 'b': 2})
print(sd)  # SortedDict({'a': 1, 'b': 2, 'c': 3})
print(sd.popitem(index=-1))  # ('c', 3)

from sortedcontainers import SortedSet

ss = SortedSet('abracadabra')
print(ss)  # SortedSet(['a', 'b', 'c', 'd', 'r'])
print(ss.bisect_left('c'))  # 2
print(ss.bisect_right('c'))  # 3
print(ss.bisect_left('f'))  # 4
print(ss.bisect_right('f'))  # 4