Example #1
from typing import List

from sortedcontainers import SortedDict


class SummaryRanges:

    def __init__(self):
        self.mp = SortedDict()

    def addNum(self, val: int) -> None:
        n = len(self.mp)
        ridx = self.mp.bisect_right(val)
        lidx = n if ridx == 0 else ridx - 1
        keys = self.mp.keys()
        values = self.mp.values()
        if lidx != n and ridx != n and values[lidx][1] + 1 == val and values[ridx][0] - 1 == val:
            self.mp[keys[lidx]][1] = self.mp[keys[ridx]][1]
            self.mp.pop(keys[ridx])
        elif lidx != n and val <= values[lidx][1] + 1:
            self.mp[keys[lidx]][1] = max(val, self.mp[keys[lidx]][1])
        elif ridx != n and val >= values[ridx][0] - 1:
            self.mp[keys[ridx]][0] = min(val, self.mp[keys[ridx]][0])
        else:
            self.mp[val] = [val, val]

    def getIntervals(self) -> List[List[int]]:
        return list(self.mp.values())


# Your SummaryRanges object will be instantiated and called as such:
# obj = SummaryRanges()
# obj.addNum(val)
# param_2 = obj.getIntervals()
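A quick usage check for the class above (assumes sortedcontainers is installed; the call sequence and expected output follow LeetCode 352's example):

obj = SummaryRanges()
for v in [1, 3, 7, 2, 6]:
    obj.addNum(v)
print(obj.getIntervals())  # [[1, 3], [6, 7]]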
Example #2
from typing import List

from sortedcontainers import SortedDict


class SummaryRanges:
    def __init__(self):
        self.mp = SortedDict()

    def addNum(self, val: int) -> None:
        n = len(self.mp)
        ridx = self.mp.bisect_right(val)
        lidx = n if ridx == 0 else ridx - 1
        keys = self.mp.keys()
        values = self.mp.values()
        if (
            lidx != n
            and ridx != n
            and values[lidx][1] + 1 == val
            and values[ridx][0] - 1 == val
        ):
            self.mp[keys[lidx]][1] = self.mp[keys[ridx]][1]
            self.mp.pop(keys[ridx])
        elif lidx != n and val <= values[lidx][1] + 1:
            self.mp[keys[lidx]][1] = max(val, self.mp[keys[lidx]][1])
        elif ridx != n and val >= values[ridx][0] - 1:
            self.mp[keys[ridx]][0] = min(val, self.mp[keys[ridx]][0])
        else:
            self.mp[val] = [val, val]

    def getIntervals(self) -> List[List[int]]:
        return list(self.mp.values())
Example #3
import numpy as np
from sortedcontainers import SortedDict


def trapezoid_decomposition_linear(polygons):
    """
    For each polygon vertex x, keep a vertical line entry [y, top, bottom]:
    the vertex height plus the nearest edge crossings above and below it,
    i.e. the vertical segments to draw in the GUI.
    """
    # Enumerate all the edges and iteratively build up the set of trapezoids
    # Add a vertical line for each point in the polygon
    all_polygons = np.concatenate(polygons, axis=0)
    vertical_lines = SortedDict(
        {x[0]: [x[1], 1000000, 0]
         for x in all_polygons})

    # Loop over Polygons to determine end-points
    for polygon in polygons:
        start_vertex = polygon[0]
        for vertex in polygon[1:]:
            # find the lines in front of the smaller
            x_start = start_vertex[0]
            x_curr = vertex[0]
            start_idx = vertical_lines.bisect_right(min(x_start, x_curr))
            end_idx = vertical_lines.bisect_left(max(x_start, x_curr))
            x_vals = vertical_lines.keys()
            for i in range(start_idx, end_idx):
                x = x_vals[i]
                if x < min(x_start, x_curr) or x > max(x_start, x_curr):
                    continue
                y, top, bottom = vertical_lines[x]
                y_val = linear_interpolation(start_vertex, vertex, x)
                if y_val > y and y_val < top:
                    vertical_lines[x][1] = y_val
                elif y_val < y and y_val > bottom:
                    vertical_lines[x][2] = y_val
            start_vertex = vertex
    return vertical_lines
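The helper linear_interpolation is not shown in the snippet; a minimal sketch of what it presumably does (evaluate the edge from p0 to p1 at abscissa x):

def linear_interpolation(p0, p1, x):
    # y on the segment p0 -> p1 at the given x (assumes p0[0] != p1[0])
    t = (x - p0[0]) / (p1[0] - p0[0])
    return p0[1] + t * (p1[1] - p0[1])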
Example #4
    def canAttendMeetings(self, intervals: List[List[int]]) -> bool:
        points = SortedDict()
        for start, end in intervals:
            # any interval boundary strictly inside (start, end) means an overlap
            i_start = points.bisect_right(start)
            i_end = points.bisect_left(end)
            if i_end != i_start:
                return False
            if i_start > 0 and points.peekitem(i_start-1)[1] == 1:
                return False

            # store boundaries as +1 (start) / -1 (end);
            # touching boundaries cancel out and are removed
            if points.get(start) == -1:
                del points[start]
            else:
                points[start] = 1

            if points.get(end) == 1:
                del points[end]
            else:
                points[end] = -1
        return True
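Assuming the method above sits in a LeetCode-style Solution class, a quick check against LeetCode 252's examples:

sol = Solution()
print(sol.canAttendMeetings([[0, 30], [5, 10], [15, 20]]))  # False
print(sol.canAttendMeetings([[7, 10], [2, 4]]))             # True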
Example #5
def test_bisect_key():
    temp = SortedDict(modulo, ((val, val) for val in range(100)))
    temp._reset(7)
    assert all(temp.bisect(val) == ((val % 10) + 1) * 10 for val in range(100))
    assert all(
        temp.bisect_right(val) == ((val % 10) + 1) * 10 for val in range(100))
    assert all(temp.bisect_left(val) == (val % 10) * 10 for val in range(100))
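The test above comes from sortedcontainers' own suite and relies on a modulo key function (and the private _reset load-factor hook) defined there; a minimal sketch of the missing piece so it runs stand-alone:

from sortedcontainers import SortedDict


def modulo(val):
    # key function: sorts the 100 keys into ten buckets of ten
    return val % 10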
Example #6
from typing import List

from sortedcontainers import SortedDict


class SummaryRanges:
    def __init__(self):
        self.intervals = SortedDict()  # key: interval's left endpoint; value: its right endpoint

    def addNum(self, val: int) -> None:
        len_ = len(self.intervals)
        keys_ = self.intervals.keys()
        values_ = self.intervals.values()

        # first interval strictly to the right of val: val < keys_[r1] when r1 != len_
        r1 = self.intervals.bisect_right(val)
        l1 = len_ if r1 == 0 else (r1 - 1)

        if l1 != len_ and keys_[l1] <= val <= values_[l1]:
            return  # val is already covered by the interval at l1

        lconnect = l1 != len_ and (values_[l1] + 1 == val)  # touches the interval on the left
        rconnect = r1 != len_ and (keys_[r1] - 1 == val)  # touches the interval on the right
        if lconnect and rconnect:
            begin, end = keys_[l1], values_[r1]
            del self.intervals[keys_[r1]]
            self.intervals[begin] = end
        elif lconnect:
            begin = keys_[l1]
            self.intervals[begin] = val
        elif rconnect:
            end = values_[r1]
            del self.intervals[keys_[r1]]
            self.intervals[val] = end
        else:
            self.intervals[val] = val

    def getIntervals(self) -> List[List[int]]:
        return list(self.intervals.items())
Example #7
from sortedcontainers import SortedDict


class MyCalendar:
    def __init__(self):
        self.sd = SortedDict()

    def book(self, start: int, end: int) -> bool:
        idx = self.sd.bisect_right(start)
        if idx < len(self.sd) and end > self.sd.values()[idx]:
            return False
        self.sd[end] = start
        return True
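A quick check against LeetCode 729's example (book returns False only when the new event overlaps an existing one):

cal = MyCalendar()
print(cal.book(10, 20))  # True
print(cal.book(15, 25))  # False, overlaps [10, 20)
print(cal.book(20, 30))  # True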
Example #8
from sortedcontainers import SortedDict


class RangeModule:

    def __init__(self):
        self.data = SortedDict()

    def addRange(self, left: int, right: int) -> None:
        l, r = self.data.bisect(left), self.data.bisect(right)
        if l != 0:
            # move l left by 1 so it points to the interval that may cover `left`
            l -= 1
            # if that interval ends before `left`, it cannot merge; move back up
            if self.data.peekitem(l)[1] < left:
                l += 1
        if l != r:
            # given the adjusted left and right indices, check whether a merge is needed:
            # take the min of the left endpoints and the max of the right endpoints
            # to maximize the interval size.
            left, right = min(left, self.data.peekitem(l)[0]), max(right, self.data.peekitem(r-1)[1])
            # now that we have the new interval we need to pop the redundant intervals
            for _ in range(l, r):
                self.data.popitem(l)
        # insert the new interval
        self.data[left] = right

    def queryRange(self, left: int, right: int) -> bool:
        l, r = self.data.bisect_right(left), self.data.bisect_right(right)
        if l == 0 or self.data.peekitem(l-1)[1] < right:
            return False
        return True

    def removeRange(self, left: int, right: int) -> None:
        l, r = self.data.bisect_right(left), self.data.bisect_right(right)
        if l != 0:
            l -= 1
            if self.data.peekitem(l)[1] < left:
                l += 1
        if l != r:
            ll, rr = min(left, self.data.peekitem(l)[0]), max(right, self.data.peekitem(r-1)[1])
            for _ in range(l, r):
                self.data.popitem(l)
            if ll < left: self.data[ll] = left
            if right < rr: self.data[right] = rr
Example #9
from sortedcontainers import SortedDict


class GradeTable:
    def __init__(self, path0=()):
        # use an immutable default to avoid the shared-mutable-default pitfall
        self.table = SortedDict()
        self.table[0] = 0, list(path0)

    def decide(self, m):
        i = self.table.bisect_right(m) - 1
        return self.table.items()[i]

    def add_grade(self, cost, gain, path=()):
        i = self.table.bisect_right(cost) - 1
        g, p = self.table.values()[i]
        if g >= gain:
            return False  # the existing tier is already at least as good as the new one
        self.table[cost] = gain, list(path)
        return True

    def __repr__(self):
        return repr(self.table)
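A small usage sketch: decide(m) returns the best (cost, (gain, path)) tier whose cost does not exceed m:

table = GradeTable()
print(table.add_grade(10, 5))  # True: new tier (cost 10, gain 5)
print(table.add_grade(12, 4))  # False: the cost-10 tier already gains 5 >= 4
print(table.decide(11))        # (10, (5, [])), the best tier affordable with m = 11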
Example #10
    def maxDepthBST(self, order):
        depths = SortedDict()
        depths[0] = 0            # sentinel below all keys
        depths[10**5 + 1] = 0    # sentinel above all keys
        for x in order:
            i = depths.bisect_right(x)
            vals = depths.values()
            left = vals[i - 1]   # depth of x's in-order predecessor
            right = vals[i]      # depth of x's in-order successor
            depths[x] = max(left, right) + 1
        return max(depths.values())
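A quick check with LeetCode 1902's first example (inserting [2, 1, 4, 3] builds a BST of depth 3), assuming the usual Solution wrapper:

print(Solution().maxDepthBST([2, 1, 4, 3]))  # 3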
Example #11
from sortedcontainers import SortedDict


class RangeModuleDict:
    def __init__(self):
        self.data = SortedDict()

    def addRange(self, left: int, right: int) -> None:
        l = self.data.bisect(left)
        r = self.data.bisect(right)
        if l != 0:
            l -= 1
            if self.data.peekitem(l)[1] < left:
                l += 1
        if l != r:
            left = min(left, self.data.peekitem(l)[0])
            right = max(right, self.data.peekitem(r - 1)[1])
            for _ in range(l, r):
                self.data.popitem(l)
        self.data[left] = right

    def queryRange(self, left: int, right: int) -> bool:
        l = self.data.bisect_right(left)
        r = self.data.bisect_right(right)
        if l == 0 or self.data.peekitem(l - 1)[1] < right:
            return False
        return True

    def removeRange(self, left: int, right: int) -> None:
        l = self.data.bisect_right(left)
        r = self.data.bisect_right(right)
        if l != 0:
            l -= 1
            if self.data.peekitem(l)[1] < left:
                l += 1
        if l != r:
            minLeft = min(left, self.data.peekitem(l)[0])
            maxRight = max(right, self.data.peekitem(r - 1)[1])
            for _ in range(l, r):
                self.data.popitem(l)
            if minLeft < left:
                self.data[minLeft] = left
            if right < maxRight:
                self.data[right] = maxRight
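A quick run against LeetCode 715's example:

rm = RangeModuleDict()
rm.addRange(10, 20)
rm.removeRange(14, 16)
print(rm.queryRange(10, 14))  # True
print(rm.queryRange(13, 15))  # False
print(rm.queryRange(16, 17))  # True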
Example #12
    def oddEvenJumps(self, A: List[int]) -> int:
        a = A
        n = len(a)

        # oj[i]: landing index of an odd (upward) jump from i, or -1
        mm = SortedDict()
        oj = [-1 for i in range(n)]
        for i in range(n - 1, 0, -1):
            mm[a[i]] = i
            j = mm.bisect_left(a[i - 1])
            if j == len(mm):
                continue
            # mm.iloc was removed in sortedcontainers 2.x; index the keys view instead
            j = mm.keys()[j]
            oj[i - 1] = mm[j]

        # ej[i]: landing index of an even (downward) jump from i, or -1
        mm = SortedDict()
        ej = [-1 for i in range(n)]
        for i in range(n - 1, 0, -1):
            mm[a[i]] = i
            j = mm.bisect_right(a[i - 1]) - 1
            if j == -1:
                continue
            j = mm.keys()[j]
            ej[i - 1] = mm[j]

        dp = {}

        def dfs(idx, odd):
            nonlocal dp

            if idx == n - 1:
                return True
            if (idx, odd) in dp:
                return dp[(idx, odd)]
            idx1 = oj[idx] if odd else ej[idx]
            if idx1 == -1:
                dp[(idx, odd)] = False
            else:
                dp[(idx, odd)] = dfs(idx1, not odd)
            return dp[(idx, odd)]

        res = 0
        for i in range(n):
            if dfs(i, True):
                res += 1
        return res
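Check against LeetCode 975's first example (from A = [10, 13, 12, 14, 15] the indices 3 and 4 reach the end, so the answer is 2), assuming the usual Solution wrapper:

print(Solution().oddEvenJumps([10, 13, 12, 14, 15]))  # 2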
Example #13
class NameMap:
    """
    Class to store statistics of names.
    For each name there is a nameEntry holding
    the name's %, cumulative %, and rank.
    """
    def __init__(self, nameFile: str):
        """
        Create a new map of names with data,
        keyed and sorted by name.
        """
        try:
            names = open(filePath + nameFile)
        except IOError:
            print("Error opening file: " + filePath + nameFile)
            raise
        self.namemap = SortedDict()
        for line in names:
            nameData = nameEntry(line)
            self.namemap[nameData.name] = nameData

    def lookup(self, name):
        """
        Look up name in the map;
        return its nameEntry, or None.
        """
        return self.namemap.get(name)

    def lookup10(self, name):
        """
        Look up the neighborhood of name:
        return up to 10 <name, nameEntry> pairs
        around its sorted position, or fewer at the ends.
        """
        i = self.namemap.bisect_right(name)
        low = max(0, i - 5)
        high = min(len(self.namemap), i + 5)
        result = []
        for j in range(low, high):
            result.append(self.namemap.peekitem(j))
        return result
Example #14
def populate_component_matrix(paths: List[Path],
                              schematic: PangenomeSchematic):
    # the loops are 1) paths, and then 2) schematic.components
    # paths are in the same order as schematic.path_names
    for i, path in enumerate(paths):
        sorted_bins = SortedDict((bin.bin_id, bin) for bin in path.bins)
        values = list(sorted_bins.values())
        for component in schematic.components:
            from_id = sorted_bins.bisect_left(component.first_bin)
            to_id = sorted_bins.bisect_right(component.last_bin)
            relevant = values[from_id:to_id]
            padded = []
            if relevant:
                padded = [[]] * (component.last_bin - component.first_bin + 1)
                for bin in relevant:
                    padded[bin.bin_id - component.first_bin] = \
                        Bin(bin.coverage, bin.inversion_rate, bin.first_nucleotide, bin.last_nucleotide)
            component.matrix.append(
                padded)  # ensure there's always 1 entry for each path
    print("Populated Matrix per component per path.")
    populate_component_occupancy(schematic)
Example #15
def calcGistFeatures(bytes, kernels):
    # Create byte array
    byteImage = []
    for row in bytes:
        rowList = row.split(' ')[1:]
        if len(rowList) != 16:
            continue
        for item in rowList:
            if item == "??":
                byteImage.append(257)  # sentinel value for unknown bytes
            else:
                byteImage.append(int(item, 16))

    # Reshape to image
    totalSizeKB = len(byteImage) / 1024
    byteImage = np.array(byteImage)
    widthMap = SortedDict({
        0: 64,
        30: 128,
        60: 256,
        100: 384,
        200: 512,
        500: 768,
        1000: 1024
    })
    width = widthMap.values()[widthMap.bisect_right(totalSizeKB) - 1]
    height = int(len(byteImage) / width)
    byteImage = byteImage[:width * height]
    byteImage = byteImage.reshape(width, -1)
    byteImage = transform.resize(byteImage, (64, 64), preserve_range=True)
    feat = compute_feats(byteImage, kernels)

    # Alternative: downsample to 64x64 with block_reduce instead of resize
    # byteImage = block_reduce(byteImage, block_size=(int(byteImage.shape[0]/64), int(byteImage.shape[1]/64)), func=np.mean)
    return feat
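The widthMap lookup above is a useful SortedDict idiom: bisect_right(x) - 1 gives the index of the largest threshold key <= x. In isolation:

from sortedcontainers import SortedDict

width_map = SortedDict({0: 64, 30: 128, 60: 256, 100: 384})
for size_kb in (10, 45, 200):
    width = width_map.values()[width_map.bisect_right(size_kb) - 1]
    print(size_kb, "->", width)  # 10 -> 64, 45 -> 128, 200 -> 384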
Example #16
def _rolling(s, aa, bb):
    # `phi`, `dt`, and `logic` are free variables from the enclosing scope
    assert aa == 0, f"{aa} -- {phi}"
    # Interpolate the whole signal to include pivot points
    d = SortedDict(
        {t: s[t][phi.arg]
         for t in s.times() if phi.arg in s[t]})
    for t in set(d.keys()):
        idx = max(d.bisect_right(t) - 1, 0)
        key = d.keys()[idx]
        i = t - bb - aa + dt
        if s.start <= i < s.end:
            d[i] = s[key][phi.arg]
    # Iterate over rolling window
    v = []
    for t in reversed(d):
        v.append((t, d[t]))
        while not (t + aa <= v[0][0] < t + bb):
            del v[0]
        x = [i for j, i in v]
        if len(x) > 0:
            yield t, logic.tnorm(x)
        else:
            yield t, d[t]
Example #17
    def oddEvenJumps(self, A: List[int]) -> int:
        n = len(A)
        m = SortedDict()
        dp = [[0] * 2 for _ in range(n)]
        dp[n - 1][0] = dp[n - 1][1] = 1
        m[A[n - 1]] = n - 1
        res = 1

        for i in range(n - 2, -1, -1):
            # odd jump: bisect_left gives the index of the first key >= A[i]
            o = m.bisect_left(A[i])
            if o != len(m):
                dp[i][0] = dp[m.items()[o][1]][1]
            # even jump: bisect_right gives the first key > A[i],
            # so e - 1 is the last key <= A[i]
            e = m.bisect_right(A[i])
            if e != 0:
                dp[i][1] = dp[m.items()[e - 1][1]][0]
            if dp[i][0]:
                res += 1
            m[A[i]] = i
        return res
Example #18
def test_bisect():
    mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)]
    temp = SortedDict(mapping)
    assert temp.bisect_left('a') == 0
    assert temp.bisect_right('f') == 6
    assert temp.bisect('f') == 6
Example #19
class GrowSpaceSortedEnv(gym.Env):
    def __init__(self,
                 width=DEFAULT_RES,
                 height=DEFAULT_RES,
                 light_dif=LIGHT_DIFFUSION):
        self.width = width
        self.height = height
        self.seed()
        self.light_dif = light_dif
        self.action_space = gym.spaces.Discrete(
            3)  # move the light paddle left, right, or keep it in place
        self.observation_space = gym.spaces.Box(0,
                                                255,
                                                shape=(height, width, 3),
                                                dtype=np.uint8)
        self.steps = 0

        # data format for branches: they are indexed/sorted by x_end position and each
        # key has a list of values that are [y_end, x_start, y_start, children]

        self.branches = SortedDict()
        self.points = SortedDict()

    def seed(self, seed=None):
        np.random.seed(seed)
        return [seed]  # gym convention: return the list of seeds used

    def light_move_R(self):
        if np.around(
                self.light_left,
                1) >= 1 - LIGHT_WIDTH - LIGHT_STEP:  # limit of coordinates
            self.light_left = 1 - LIGHT_WIDTH  # stay put
        else:
            self.light_left += LIGHT_STEP  # move by .1 right

    def light_move_L(self):
        if np.around(self.light_left, 1) <= LIGHT_STEP:  # limit of coordinates
            self.light_left = 0
        else:
            self.light_left -= LIGHT_STEP  # move by .1 left

    def find_closest_branch(self, point_x, branches):
        branch_names = []
        branch_distances = []
        # prefilter by x
        if len(branches) > MAX_BRANCHES:
            branches_trimmed = sample(branches, MAX_BRANCHES)
        else:
            branches_trimmed = branches
        for branch in branches_trimmed:
            dist_x = branch - point_x
            if np.abs(dist_x) <= MAX_GROW_DIST:
                # we got a potential candidate - now let's check Y
                dist_y = self.branches[branch][0] - self.points[point_x]
                if np.abs(dist_y) <= MAX_GROW_DIST:
                    dist = norm((dist_x, dist_y))
                    if dist <= MAX_GROW_DIST:
                        branch_names.append(branch)
                        branch_distances.append(dist)
        if len(branch_distances) == 0:
            return None, None
        argmin = np.argmin(branch_distances)
        return branch_names[argmin], branch_distances[argmin]

    def grow_plant(self):
        points_filtered = list(
            self.get_points_in_range(self.light_left - MAX_GROW_DIST,
                                     self.light_right + MAX_GROW_DIST))
        branches_filtered = list(
            self.get_branches_in_range(self.light_left, self.light_right))

        growths = {}  # will have the format: [(branch, target_x)]

        for point in points_filtered:
            closest_branch, dist = self.find_closest_branch(
                point, branches_filtered)
            if closest_branch is None:
                continue
            if dist < MIN_GROW_DIST:
                self.points.pop(point)
            elif dist < MAX_GROW_DIST:
                if closest_branch not in growths:
                    growths[closest_branch] = [point]
                else:
                    growths[closest_branch].append(point)

        for branch, points in growths.items():
            end_x = (branch +
                     (sum(points) / len(points) - branch) * BRANCH_LENGTH
                     )  # move a fraction BRANCH_LENGTH towards the points' mean x
            branch_y = self.branches[branch][0]
            point_ys = [self.points[p] for p in points]
            end_y = branch_y + (sum(point_ys) / len(point_ys) -
                                branch_y) * BRANCH_LENGTH
            while end_x in self.branches:
                end_x += EPSILON  # keys need to be unique in branches dict
            self.branches[end_x] = [end_y, branch, self.branches[branch][0], 0]

        # update_all_branch_widths(branches)

    def get_points_in_range(self, start, end):
        # SortedDict.irange lazily yields the keys in [start, end]
        return self.points.irange(start, end)

    def get_branches_in_range(self, start, end):
        return self.branches.irange(start, end)

    def branch_bisect_range(self, lower, upper):
        # a SortedDict cannot be sliced by position directly;
        # slice the keys view with the bisected indices instead
        start = self.branches.bisect(lower)
        end = self.branches.bisect_right(upper)
        return [self.branches[k] for k in self.branches.keys()[start:end]]

    def get_branch_start_end_thiccness(self, end_x):
        end_y, start_x, start_y, children = self.branches[end_x]
        thicc = ir((children + 1) * BRANCH_THICCNESS * self.width)
        return (
            (ir(start_x * self.width), ir(start_y * self.height)),
            (ir(end_x * self.width), ir(end_y * self.height)),
            thicc,
        )

    def get_observation(self, debug_show_scatter=False):
        # new empty image
        img = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # place light as rectangle
        x1 = ir(self.light_left * self.width)
        x2 = ir(self.light_right * self.width)
        cv2.rectangle(img,
                      pt1=(x1, 0),
                      pt2=(x2, self.height),
                      color=LIGHT_COLOR,
                      thickness=-1)

        if debug_show_scatter:
            points_filtered = self.get_points_in_range(self.light_left,
                                                       self.light_right)
            for k in list(points_filtered):
                x = ir(k * self.width)
                y = ir(self.points[k] * self.height)
                cv2.circle(img,
                           center=(x, y),
                           radius=POINT_RADIUS,
                           color=POINT_COLOR,
                           thickness=-1)

        # Draw plant as series of lines (1 branch = 1 line)
        for branch_x_end in self.branches.keys():
            start, end, thiccness = self.get_branch_start_end_thiccness(
                branch_x_end)
            cv2.line(img,
                     pt1=start,
                     pt2=end,
                     color=PLANT_COLOR,
                     thickness=thiccness)

        # place goal as filled circle with center and radius
        # also important - place goal last because must be always visible
        x = ir(self.target[0] * self.width)
        y = ir(self.target[1] * self.height)
        cv2.circle(img,
                   center=(x, y),
                   radius=ir(0.03 * self.width),
                   color=(0, 0, 255),
                   thickness=-1)

        # flip image, because plant grows from the bottom, not the top
        img = cv2.flip(img, 0)

        return img

    def reset(self):
        random_start = np.random.rand()  # is in range [0, 1)
        self.branches.clear()
        self.points.clear()

        self.branches[random_start] = [FIRST_BRANCH_HEIGHT, random_start, 0, 0]

        self.target = [np.random.uniform(0, 1), np.random.uniform(0.8, 1)]
        if random_start >= (1 - LIGHT_WIDTH / 2):
            self.light_left = 1 - LIGHT_WIDTH
        elif random_start <= LIGHT_WIDTH / 2:
            self.light_left = 0
        else:
            self.light_left = random_start - (LIGHT_WIDTH / 2)

        self.light_right = self.light_left + LIGHT_WIDTH

        points_x = np.random.uniform(0, 1, self.light_dif)
        points_y = np.random.uniform(FIRST_BRANCH_HEIGHT + 0.1, 1,
                                     self.light_dif)

        for i in range(self.light_dif):
            while points_x[i] in self.points:
                points_x[i] += EPSILON
            self.points[points_x[i]] = points_y[i]

        self.steps = 0

        return self.get_observation()

    def step(self, action):
        # Two possible actions, move light left or right

        if action == 0:
            self.light_move_L()

        if action == 1:
            self.light_move_R()

        self.light_right = self.light_left + LIGHT_WIDTH

        if action == 2:
            # then we keep the light in place
            pass

        self.grow_plant()

        # # Calculate distance to target
        # reward = 1 / self.distance_target(tips)

        ####### TODO

        reward = 0  # TODO

        ####### TODO

        # Render image of environment at current state
        observation = self.get_observation()  # image

        done = False  # because we don't have a terminal condition
        misc = {}  # (optional) additional information about plant/episode, empty for now
        self.steps += 1
        return observation, reward, done, misc

    def render(self,
               mode="human",
               debug_show_scatter=False):  # or mode="rgb_array"
        img = self.get_observation(debug_show_scatter)

        if mode == "human":
            cv2.imshow("plant", img)  # create opencv window to show plant
            cv2.waitKey(
                1)  # this is necessary or the window closes immediately
        else:
            return img
Example #21
class ColorCodePatchBuilder(PickablePatchBuilder):
    """
    The patch generator builds the matplotlib patches for each
    capability node.

    The nodes are rendered as lines with a different color depending
    on the permission bits of the capability. The builder produces
    a LineCollection for each combination of permission bits and
    creates the lines for the nodes.
    """
    def __init__(self, figure):
        super(ColorCodePatchBuilder, self).__init__(figure, None)

        self.y_unit = 10**-6
        """Unit on the y-axis"""

        # permission composition shorthands
        load_store = CheriCapPerm.LOAD | CheriCapPerm.STORE
        load_exec = CheriCapPerm.LOAD | CheriCapPerm.EXEC
        store_exec = CheriCapPerm.STORE | CheriCapPerm.EXEC
        load_store_exec = (CheriCapPerm.STORE | CheriCapPerm.LOAD
                           | CheriCapPerm.EXEC)

        self._collection_map = {
            0: [],
            CheriCapPerm.LOAD: [],
            CheriCapPerm.STORE: [],
            CheriCapPerm.EXEC: [],
            load_store: [],
            load_exec: [],
            store_exec: [],
            load_store_exec: [],
            "call": [],
        }
        """Map capability permission to the set where the line should go"""

        self._colors = {
            0: colorConverter.to_rgb("#bcbcbc"),
            CheriCapPerm.LOAD: colorConverter.to_rgb("k"),
            CheriCapPerm.STORE: colorConverter.to_rgb("y"),
            CheriCapPerm.EXEC: colorConverter.to_rgb("m"),
            load_store: colorConverter.to_rgb("c"),
            load_exec: colorConverter.to_rgb("b"),
            store_exec: colorConverter.to_rgb("g"),
            load_store_exec: colorConverter.to_rgb("r"),
            "call": colorConverter.to_rgb("#31c648"),
        }
        """Map capability permission to line colors"""

        self._patches = None
        """List of generated patches"""

        self._node_map = SortedDict()
        """Maps the Y axis coordinate to the graph node at that position"""

    def _build_patch(self, node_range, y, perms):
        """
        Build patch for the given range and type and add it
        to the patch collection for drawing
        """
        line = [(node_range.start, y), (node_range.end, y)]

        if perms is None:
            perms = 0
        rwx_perm = perms & (CheriCapPerm.LOAD | CheriCapPerm.STORE
                            | CheriCapPerm.EXEC)
        self._collection_map[rwx_perm].append(line)

    def _build_call_patch(self, node_range, y, origin):
        """
        Build patch for a node representing a system call
        This is added to a different collection so it can be
        colored differently.
        """
        line = [(node_range.start, y), (node_range.end, y)]
        self._collection_map["call"].append(line)

    def inspect(self, node):
        """
        Inspect a graph vertex and create the patches for it.
        """
        if node.cap.bound < node.cap.base:
            logger.warning("Skip overflowed node %s", node)
            return
        node_y = node.cap.t_alloc * self.y_unit
        node_box = transforms.Bbox.from_extents(node.cap.base, node_y,
                                                node.cap.bound, node_y)

        self._bbox = transforms.Bbox.union([self._bbox, node_box])
        keep_range = Range(node.cap.base, node.cap.bound, Range.T_KEEP)
        if node.origin == CheriNodeOrigin.SYS_MMAP:
            self._build_call_patch(keep_range, node_y, node.origin)
        else:
            self._build_patch(keep_range, node_y, node.cap.permissions)

        self._node_map[node.cap.t_alloc] = node

        # invalidate collections
        self._patches = None

    def get_patches(self):
        if self._patches:
            return self._patches
        self._patches = []
        for key, collection in self._collection_map.items():
            coll = collections.LineCollection(collection,
                                              colors=[self._colors[key]],
                                              linestyle="solid")
            self._patches.append(coll)
        return self._patches

    def get_legend(self):
        if not self._patches:
            self.get_patches()
        legend = ([], [])
        for patch, key in zip(self._patches, self._collection_map.keys()):
            legend[0].append(patch)
            if key == "call":
                legend[1].append("mmap")
            else:
                perm_string = ""
                if key & CheriCapPerm.LOAD:
                    perm_string += "R"
                if key & CheriCapPerm.STORE:
                    perm_string += "W"
                if key & CheriCapPerm.EXEC:
                    perm_string += "X"
                if perm_string == "":
                    perm_string = "None"
                legend[1].append(perm_string)
        return legend

    def on_click(self, event):
        """
        Attempt to retrieve the data in less than O(n) for better
        interactivity at the expense of having to hold a dictionary of
        references to nodes for each t_alloc.
        Note that t_alloc is unique for each capability node as it
        is the cycle count, so it can be used as the key.
        """
        ax = event.inaxes
        if ax is None:
            return

        # back to data coords without scaling
        y_coord = int(event.ydata / self.y_unit)
        y_max = self._bbox.ymax / self.y_unit
        # tolerance for y distance, 0.25 units
        epsilon = 0.25 / self.y_unit

        # try to get the node closest to y_coord the fast way;
        # for now fall back to a reduced linear search, though it would be
        # useful to be able to index lines with an R-tree
        idx_min = self._node_map.bisect_left(max(0, y_coord - epsilon))
        idx_max = self._node_map.bisect_right(min(y_max, y_coord + epsilon))
        iter_keys = self._node_map.islice(idx_min, idx_max)
        # the closest node to the click position
        # initialize it with the first node in the search range
        try:
            pick_target = self._node_map[next(iter_keys)]
        except StopIteration:
            # no match found
            ax.set_status_message("")
            return

        for key in iter_keys:
            node = self._node_map[key]
            if (node.cap.base <= event.xdata and node.cap.bound >= event.xdata
                    and abs(y_coord - key) <
                    abs(y_coord - pick_target.cap.t_alloc)):
                # the click event is within the node bounds and
                # the node Y is closer to the click event than
                # the previous pick_target
                pick_target = node
        ax.set_status_message(pick_target)
Example #22
import operator

from sortedcontainers import SortedDict


class LevelTrace(object):
    """ Traces the level of some entity across a time span """

    def __init__(self, trace=None):
        """ Creates a new level trace, possibly copying from an existing object. """

        if trace is None:
            self._trace = SortedDict()
        elif isinstance(trace, LevelTrace):
            self._trace = SortedDict(trace._trace)
        else:
            self._trace = SortedDict(trace)

        # Make sure trace is terminated (returns to 0)
        if len(self._trace) > 0 and self._trace[self._trace.keys()[-1]] != 0:
            raise ValueError(
                "Trace not terminated - ends with {}:{}!".format(
                    self._trace.keys()[-1], self._trace[self._trace.keys()[-1]])
                )

    def __repr__(self):
        items = ', '.join(["{!r}: {!r}".format(k, v)
                           for k, v in self._trace.items()])
        return "LevelTrace({{{}}})".format(items)

    def __eq__(self, other):
        return self._trace == other._trace

    def __neg__(self):
        return self.map(operator.neg)

    def __sub__(self, other):
        return self.zip_with(other, operator.sub)

    def __add__(self, other):
        return self.zip_with(other, operator.add)

    def start(self):
        """ Returns first non-null point in trace """
        if len(self._trace) == 0:
            return 0
        return self._trace.keys()[0]

    def end(self):
        """ Returns first point in trace that is null and only followed by nulls """
        if len(self._trace) == 0:
            return 0
        return self._trace.keys()[-1]

    def length(self):
        if len(self._trace) == 0:
            return 0
        return self.end() - self.start()

    def get(self, time):
        ix = self._trace.bisect_right(time) - 1
        if ix < 0:
            return 0
        else:
            (_, lvl) = self._trace.peekitem(ix)
            return lvl

    def map(self, fn):
        return LevelTrace({t: fn(v) for t, v in self._trace.items()})

    def map_key(self, fn):
        return LevelTrace(dict(fn(t, v) for t, v in self._trace.items()))

    def shift(self, time):
        return self.map_key(lambda t, v: (t + time, v))

    def __getitem__(self, where):

        # For non-slices defaults to get
        if not isinstance(where, slice):
            return self.get(where)
        if where.step is not None:
            raise ValueError("Stepping meaningless for LevelTrace!")

        # Limit
        res = LevelTrace(self)
        if where.start is not None and where.start > res.start():
            res.set(res.start(), where.start, 0)
        if where.stop is not None and where.stop < res.end():
            res.set(where.stop, res.end(), 0)

        # Shift, if necessary
        if where.start is not None:
            res = res.shift(-where.start)
        return res

    def set(self, start, end, level):
        """ Sets the level for some time range
        :param start: Start of range
        :param end: End of range
        :param level: Level to set
        """

        # Check errors, no-ops
        if start >= end:
            return

        # Determine levels at start (and before start)
        start_ix = self._trace.bisect_right(start) - 1
        prev_lvl = lvl = 0
        if start_ix >= 0:
            (t, lvl) = self._trace.peekitem(start_ix)
            # If we have no entry exactly at our start point, the
            # level was constant at this point before
            if start > t:
                prev_lvl = lvl
            # Otherwise look up previous level. Default 0 (see above)
            elif start_ix > 0:
                (_, prev_lvl) = self._trace.peekitem(start_ix-1)

        # Prepare start
        if prev_lvl == level:
            if start in self._trace:
                del self._trace[start]
        else:
            self._trace[start] = level

        # Remove all in-between states
        for time in list(self._trace.irange(start, end, inclusive=(False, False))):
            lvl = self._trace[time]
            del self._trace[time]

        # Add or remove end, if necessary
        if end not in self._trace:
            if lvl != level:
                self._trace[end] = lvl
        elif level == self._trace[end]:
            del self._trace[end]


    def add(self, start, end, amount):
        """ Increases the level for some time range
        :param start: Start of range
        :param end: End of range
        :param amount: Amount to add to level
        """

        # Check errors, no-ops
        if start > end:
            raise ValueError("End needs to be after start!")
        if start == end or amount == 0:
            return

        # Determine levels at start (and before start)
        start_ix = self._trace.bisect_right(start) - 1
        prev_lvl = lvl = 0
        if start_ix >= 0:
            (t, lvl) = self._trace.peekitem(start_ix)
            # If we have no entry exactly at our start point, the
            # level was constant at this point before
            if start > t:
                prev_lvl = lvl
            # Otherwise look up previous level. Default 0 (see above)
            elif start_ix > 0:
                (_, prev_lvl) = self._trace.peekitem(start_ix-1)

        # Prepare start
        if prev_lvl == lvl + amount:
            del self._trace[start]
        else:
            self._trace[start] = lvl + amount

        # Update all in-between states
        for time in self._trace.irange(start, end, inclusive=(False, False)):
            lvl = self._trace[time]
            self._trace[time] = lvl + amount

        # Add or remove end, if necessary
        if end not in self._trace:
            self._trace[end] = lvl
        elif lvl + amount == self._trace[end]:
            del self._trace[end]

    def __delitem__(self, where):

        # Cannot delete single values
        if not isinstance(where, slice):
            raise ValueError("Cannot delete level for a single point, pass an interval!")
        if where.step is not None:
            raise ValueError("Stepping meaningless for LevelTrace!")

        # Set range to zero
        start = (where.start if where.start is not None else self.start())
        end = (where.stop if where.stop is not None else self.end())
        self.set(start, end, 0)

    def __setitem__(self, where, value):

        # Cannot set single values
        if not isinstance(where, slice):
            raise ValueError("Cannot set level for single point, pass an interval!")
        if where.step is not None:
            raise ValueError("Stepping meaningless for LevelTrace!")

        # Setting a level trace?
        if isinstance(value, LevelTrace):

            # Remove existing data
            del self[where]
            if where.start is not None:
                if value.start() < 0:
                    raise ValueError("Level trace starts before 0!")
                value = value.shift(where.start)
            if where.stop is not None:
                if value.end() > where.stop:
                    raise ValueError("Level trace to set is larger than slice!")
            self._trace = (self + value)._trace

        else:

            # Otherwise set constant value
            start = (where.start if where.start is not None else self.start())
            end = (where.stop if where.stop is not None else self.end())
            self.set(start, end, value)

    def foldl1(self, start, end, fn):
        """
        Does a left-fold over the levels present in the given range. Seeds
        with level at start.
        """

        if start > end:
            raise ValueError("End needs to be after start!")
        val = self.get(start)
        start_ix = self._trace.bisect_right(start)
        end_ix = self._trace.bisect_left(end)
        for lvl in self._trace.values()[start_ix:end_ix]:
            val = fn(val, lvl)
        return val

    def minimum(self, start, end):
        """ Returns the lowest level in the given range """
        return self.foldl1(start, end, min)

    def maximum(self, start, end):
        """ Returns the highest level in the given range """
        return self.foldl1(start, end, max)

    def foldl_time(self, start, end, val, fn):
        """
        Does a left-fold over the levels present in the given range,
        also passing how long the level was held. Seed passed.
        """

        if start > end:
            raise ValueError("End needs to be after start!")

        last_time = start
        last_lvl = self.get(start)

        start_ix = self._trace.bisect_right(start)
        end_ix = self._trace.bisect_left(end)
        for time, lvl in self._trace.items()[start_ix:end_ix]:
            val = fn(val, time-last_time, last_lvl)
            last_time = time
            last_lvl = lvl

        return fn(val, end-last_time, last_lvl)

    def integrate(self, start, end):
        """ Returns the integral over a range (sum below level curve) """
        return self.foldl_time(start, end, 0,
                               lambda v, time, lvl: v + time * lvl)
    def average(self, start, end):
        """ Returns the average level over a given range """
        return self.integrate(start, end) / (end - start)

    def find_above(self, time, level):
        """Returns the first time larger or equal to the given start time
        where the level is at least the specified value.
        """

        if self.get(time) >= level:
            return time
        ix = self._trace.bisect_right(time)
        for t, lvl in self._trace.items()[ix:]:
            if lvl >= level:
                return t
        return None

    def find_below(self, time, level):
        """Returns the first time larger or equal to the given start time
        where the level is less or equal the specified value.
        """

        if self.get(time) <= level:
            return time
        ix = self._trace.bisect_right(time)
        for t, lvl in self._trace.items()[ix:]:
            if lvl <= level:
                return t
        return None

    def find_below_backward(self, time, level):
        """Returns the last time smaller or equal to the given time where
        there exists a region to the left where the level is below the
        given value.
        """

        last = time
        ix = self._trace.bisect_right(time)-1
        if ix >= 0:
            for t, lvl in self._trace.items()[ix::-1]:
                if lvl <= level and time > t:
                    return last
                last = t
        if level >= 0:
            return last
        return None

    def find_above_backward(self, time, level):
        """Returns the last time smaller or equal to the given time where
        there exists a region to the left where the level is above the
        given value.
        """

        last = time
        ix = self._trace.bisect_right(time)-1
        if ix >= 0:
            for t, lvl in self._trace.items()[ix::-1]:
                if lvl >= level and time > t:
                    return last
                last = t
        if level <= 0:
            return last
        return None

    def find_period_below(self, start, end, target, length):
        """Returns a period where the level is below the target for a certain
        length of time, within a given start and end time"""

        if start > end:
            raise ValueError("End needs to be after start!")
        if length < 0:
            raise ValueError("Period length must be larger than zero!")

        period_start = (start if self.get(start) <= target else None)

        start_ix = self._trace.bisect_right(start)
        end_ix = self._trace.bisect_left(end)
        for time, lvl in self._trace.items()[start_ix:end_ix]:
            # Period long enough?
            if period_start is not None:
                if time >= period_start + length:
                    return period_start
            # Not enough space until end?
            elif time + length > end:
                return None
            # Above target? Reset period
            if lvl > target:
                period_start = None
            else:
                if period_start is None:
                    period_start = time

        # Possible at end?
        if period_start is not None and period_start+length <= end:
            return period_start

        # Nothing found
        return None

    def zip_with(self, other, fn):

        # Simple cases
        if len(self._trace) == 0:
            return other.map(lambda x: fn(0, x))
        if len(other._trace) == 0:
            return self.map(lambda x: fn(x, 0))

        # Read first item from both sides
        left = self._trace.items()
        right = other._trace.items()
        left_ix = 0
        right_ix = 0
        left_val = 0
        right_val = 0
        last_val = 0

        trace = SortedDict()

        # Go through pairs
        while left_ix < len(left) and right_ix < len(right):

            # Next items
            lt,lv = left[left_ix]
            rt,rv = right[right_ix]

            # Determine what to do
            if lt < rt:
                v = fn(lv, right_val)
                if v != last_val:
                    last_val = trace[lt] = v
                left_val = lv
                left_ix += 1
            elif lt > rt:
                v = fn(left_val, rv)
                if v != last_val:
                    last_val = trace[rt] = v
                right_val = rv
                right_ix += 1
            else:
                v = fn(lv, rv)
                if v != last_val:
                    last_val = trace[lt] = v
                left_val = lv
                left_ix += 1
                right_val = rv
                right_ix += 1

        # Handle left-overs
        while left_ix < len(left):
            lt,lv = left[left_ix]
            v = fn(lv, right_val)
            if v != last_val:
                last_val = trace[lt] = v
            left_ix += 1
        while right_ix < len(right):
            rt,rv = right[right_ix]
            v = fn(left_val, rv)
            if v != last_val:
                last_val = trace[rt] = v
            right_ix += 1

        return LevelTrace(trace)
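A small usage sketch of LevelTrace (levels are piecewise-constant; get is a floor lookup via bisect_right - 1):

lt = LevelTrace()
lt.add(0, 10, 1)            # level 1 on [0, 10)
lt.add(5, 15, 2)            # stack level 2 on [5, 15)
print(lt.get(7))            # 3
print(lt.maximum(0, 15))    # 3
print(lt.integrate(0, 15))  # 1*5 + 3*5 + 2*5 = 30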
Example #23
class DeviceCounter():
    """
    Measure the number of available devices within a time period.
    """

    def __init__(self, start, end, local=False, debug=False, **kwargs):
        """
        Initialize a new `DeviceCounter` instance for the given range of time.

        Required positional arguments:

        :start: A python `datetime`, pandas `Timestamp`, or Unix timestamp for the beginning of the counting interval.

        :end: A python `datetime`, pandas `Timestamp`, or Unix timestamp for the end of the counting interval.

        Optional keyword arguments:

        :local: `False` (default) to assume Unix time; `True` to assume local time.

        :debug: `False` (default) to suppress debug messages; `True` to print to stdout.
        """
        if start is None or end is None:
            raise TypeError(f"'NoneType' was unexpected for start and/or end. Expected datetime, Timestamp, or Unix timestamp")

        self.start = start
        self.end = end
        self._start = self._ts2int(start)
        self._end = self._ts2int(end)
        self.interval = CounterInterval(self._start, self._end)
        self.delta = self.interval.delta
        self.local = local
        self.debug = debug

        self._reset()

        if self.debug:
            print(f"self.interval: {self.interval}")
            print()

    def _reset(self):
        """
        Resets this counter with the initial interval.
        """
        self.counts = SortedDict({ self.interval : 0 })

        # debug info
        self.events = 0
        self.splits = 0
        self.counter = 0

    def _int2ts(self, i):
        """
        Convert :i: to a Timestamp
        """
        return pandas.Timestamp(i, unit="s")

    def _ts2int(self, ts):
        """
        Try to convert :ts: to a integer
        """
        try:
            return int(ts.timestamp())
        except (AttributeError, ValueError):
            # not a datetime/Timestamp; assume it is already a Unix timestamp
            return int(ts)

    def _interval(self, key_index):
        """
        Get the Interval by index in the sorted key list
        """
        return self.counts.keys()[key_index]

    def _insertidx(self, start, end, default_index=0):
        """
        Get an insertion index for an interval with the given endpoints.
        """
        # the index for the closest known sub-interval to the event's timespan
        index = self.counts.bisect_right(CounterInterval(start, end or self._end)) - 1
        # using the start of the interval as the default
        return index if index >= 0 else default_index

    def count_event(self, event_start, event_end):
        """
        Increment the counter for the given interval of time.

        :event_start: A python `datetime`, pandas `Timestamp`, or Unix timestamp marking the beginning of the event interval.

        :event_end: A python `datetime`, pandas `Timestamp`, or Unix timestamp for the end of the event interval, or `None` for
        an event with an open interval.

        Performs a right-bisection on the counter intervals, assigning counts to increasingly
        finer slices based on the event timespan's intersection with the existing counter intervals.
        """
        event_start = self._ts2int(event_start)
        event_end = None if (event_end is None or event_end is pandas.NaT) else self._ts2int(event_end)
        to_remove = SortedSet()
        to_add = SortedDict()

        _counter = self.counter
        _splits = self.splits

        # get the next insertion index
        index = self._insertidx(start=event_start, end=event_end)

        # move the index to the right, splitting the existing intervals and incrementing counts along the way
        while index < len(self.counts) and (event_end is None or self._interval(index).start < event_end):
            interval = self._interval(index)
            count = self.counts[interval]
            start, end = interval.start, interval.end

            # the event has a closed timespan: [event_start, event_end]
            if event_end is not None:
                # event fully spans and contains the interval
                if event_start <= start and event_end >= end:
                    # increment the interval
                    to_add[CounterInterval(start, end)] = count + 1
                    self.counter += 1

                # event starts before the interval and overlaps from the left
                elif event_start <= start and event_end > start and event_end < end:
                    # subdivide and increment the affected sub-interval
                    # [start, end] -> [start, event_end]+, [event_end, end]
                    to_remove.add(interval)
                    to_add[CounterInterval(start, event_end)] = count + 1
                    to_add[CounterInterval(event_end, end)] = count
                    self.splits += 1
                    self.counter += 1

                # event starts in the interval and overlaps on the right
                elif event_start > start and event_start < end and event_end >= end:
                    # subdivide and increment the affected interval
                    # [start, end] -> [start, event_start], [event_start, end]+
                    to_remove.add(interval)
                    to_add[CounterInterval(start, event_start)] = count
                    to_add[CounterInterval(event_start, end)] = count + 1
                    self.splits += 1
                    self.counter += 1

                # event is fully within and contained by the interval
                elif event_start > start and event_end < end:
                    # subdivide and increment the affected interval
                    # [start, end] -> [start, event_start], [event_start, event_end]+, [event_end, end]
                    to_remove.add(interval)
                    to_add[CounterInterval(start, event_start)] = count
                    to_add[CounterInterval(event_start, event_end)] = count + 1
                    to_add[CounterInterval(event_end, end)] = count
                    self.splits += 2
                    self.counter += 1

            # the event has an open timespan: [event_start, )
            else:
                # event starts before the interval
                if event_start <= start:
                    # increment the interval
                    to_add[CounterInterval(start, end)] = count + 1
                    self.counter += 1

                # event starts inside the interval
                elif event_start > start and event_start <= end:
                    # subdivide and increment the affected interval
                    # [start, end] -> [start, event_start], [event_start, end]+
                    to_remove.add(interval)
                    to_add[CounterInterval(start, event_start)] = count
                    to_add[CounterInterval(event_start, end)] = count + 1
                    self.splits += 1
                    self.counter += 1

            index += 1

        for r in to_remove:
            self.counts.pop(r)

        for k in to_add.keys():
            self.counts[k] = to_add[k]

        if self.debug:
            debug = {
                "start": event_start,
                "end": event_end,
                "index": index,
                "remove": len(to_remove),
                "split": int(self.splits - _splits),
                "add": len(to_add),
                "counter": int(self.counter - _counter)
            }
            print(", ".join([f"{k}: {v}" for k, v in debug.items()]))

        self.events += 1

        return self

    def count(self, data, predicate=None):
        """
        Count device availability observed in data, over this counter's interval.

        :data: A `pandas.DataFrame` of records from the availability view.

        :predicate: A function with 3 positional args: this `DeviceCounter`, an index, and corresponding row from :data:.
        This function will be called before the given row is evaluated; if `True`, the row is counted.

        :returns: This `DeviceCounter` instance.
        """
        if self.debug:
            print(f"Generating f(x) over [{self.start}, {self.end}] with {len(data)} input records")
            print()

        self._reset()

        assert(len(self.counts) == 1)
        assert(self.counts.keys()[0] == self.interval)

        scale = ceil(len(data) / 10)

        # using this counter's initial interval as a starting point,
        # subdivide based on the intersection of the interval from each event in the data,
        # incrementing a counter for each sub-interval created along the way
        for index, row in data.iterrows():
            if self.debug and index % scale == 0:
                print(f"Processing {index + 1} of {len(data)}")

            if predicate is None or predicate(self, index, row):
                if self.local:
                    self.count_event(row["start_time_local"], row["end_time_local"])
                else:
                    self.count_event(row["start_time"], row["end_time"])

        if self.debug:
            print("Partitioning complete.")
            print(f"events: {self.events}, splits: {self.splits}, counter: {self.counter}")

        return self

    def partition(self):
        """
        Returns the current interval partition as a `pandas.DataFrame`.
        """
        partition = [{ "start": i.start,
                        "end": i.end,
                        "delta": i.delta,
                        "count": c,
                        "start_date": self._int2ts(i.start),
                        "end_date": self._int2ts(i.end) }
                    for i, c in self.counts.items()]

        return pandas.DataFrame.from_records(partition,
            columns=["start", "end", "delta", "count", "start_date", "end_date"])

    def delta_x(self):
        """
        :return: The ordered list of deltas for this interval's partition.
        """
        partition = self.partition()
        return partition["delta"]

    def norm(self):
        """
        Get the delta of the largest sub-interval in this interval's partition.
        """
        return max(self.delta_x())

    def dimension(self):
        """
        The number of sub-intervals in this interval's partition.
        """
        return len(self.partition())

    def average(self):
        """
        Estimate the average number of devices within this interval's partition.

        Use a Riemann sum to estimate, computing the area of each sub-interval in the partition:

        - height: the count of devices seen during that timeslice
        - width:  the length of the timeslice in seconds
        """
        partition = self.partition()

        if self.debug:
            print(f"Computing average across {self.dimension()} subintervals.")

        areas = partition.apply(lambda i: i["count"] * i["delta"], axis="columns")
        sigma = areas.agg("sum")

        if self.debug:
            print("sigma:", sigma)
            print("delta:", self.delta)

        # Compute the average value over this counter's interval
        return sigma / self.delta
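
A minimal sketch (not part of the original example; the numbers are made up) of the Riemann-sum estimate that average() computes: each sub-interval of the partition contributes count * delta, and the estimated average is the total area divided by the length of the whole interval.

from sortedcontainers import SortedDict

# (start, end) -> device count observed during that sub-interval
counts = SortedDict({(0, 10): 2, (10, 25): 3, (25, 60): 1})

# the area of each "rectangle" is count (height) * delta (width)
area = sum(c * (end - start) for (start, end), c in counts.items())
print(area / (60 - 0))  # (20 + 45 + 35) / 60 ~= 1.67 devices on average
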
Example #24
class DownloadTask(QObject):
    download_ready = Signal(QObject)
    download_not_ready = Signal(QObject)
    download_complete = Signal(QObject)
    download_failed = Signal(QObject)
    download_error = Signal(str)
    download_ok = Signal()

    download_finishing = Signal()
    copy_added = Signal(str)
    chunk_downloaded = Signal(
        str,  # obj_id
        str,  # str(offset) to fix offset >= 2**31
        int)  # length
    chunk_aborted = Signal()
    request_data = Signal(
        str,  # node_id
        str,  # obj_id
        str,  # str(offset) to fix offset >= 2**31
        int)  # length
    abort_data = Signal(
        str,  # node_id
        str,  # obj_id
        str)  # str(offset) to fix offset >= 2**31
    possibly_sync_folder_is_removed = Signal()
    no_disk_space = Signal(
        QObject,  # task
        str,  # display_name
        bool)  # is error
    wrong_hash = Signal(QObject)  # task
    signal_info_rx = Signal(tuple)

    default_part_size = DOWNLOAD_PART_SIZE
    receive_timeout = 20  # seconds
    retry_limit = 2
    timeouts_limit = 2
    max_node_chunk_requests = 128
    end_race_timeout = 5.  # seconds

    def __init__(self,
                 tracker,
                 connectivity_service,
                 priority,
                 obj_id,
                 obj_size,
                 file_path,
                 display_name,
                 file_hash=None,
                 parent=None,
                 files_info=None):
        QObject.__init__(self, parent=parent)
        self._tracker = tracker
        self._connectivity_service = connectivity_service

        self.priority = priority
        self.size = obj_size
        self.id = obj_id
        self.file_path = file_path
        self.file_hash = file_hash
        self.download_path = file_path + '.download'
        self._info_path = file_path + '.info'
        self.display_name = display_name
        self.received = 0
        self.files_info = files_info

        self.hash_is_wrong = False
        self._ready = False
        self._started = False
        self._paused = False
        self._finished = False
        self._no_disk_space_error = False

        self._wanted_chunks = SortedDict()
        self._downloaded_chunks = SortedDict()
        self._nodes_available_chunks = dict()
        self._nodes_requested_chunks = dict()
        self._nodes_last_receive_time = dict()
        self._nodes_downloaded_chunks_count = dict()
        self._nodes_timeouts_count = dict()
        self._total_chunks_count = 0

        self._file = None
        self._info_file = None

        self._started_time = time()

        self._took_from_turn = 0
        self._received_via_turn = 0
        self._received_via_p2p = 0

        self._retry = 0

        self._limiter = None

        self._init_wanted_chunks()

        self._on_downloaded_cb = None
        self._on_failed_cb = None
        self.download_complete.connect(self._on_downloaded)
        self.download_failed.connect(self._on_failed)

        self._timeout_timer = QTimer(self)
        self._timeout_timer.setInterval(15 * 1000)
        self._timeout_timer.setSingleShot(False)
        self._timeout_timer.timeout.connect(self._on_check_timeouts)

        self._leaky_timer = QTimer(self)
        self._leaky_timer.setInterval(1000)
        self._leaky_timer.setSingleShot(True)
        self._leaky_timer.timeout.connect(self._download_chunks)

        self._network_limited_error_set = False

    def __lt__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__lt__(self, other)

        if self == other:
            return False

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id < other.id

            return self.size - self.received < other.size - other.received

        return self.priority > other.priority

    def __le__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__le__(self, other)

        if self == other:
            return True

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id < other.id

            return self.size - self.received < other.size - other.received

        return self.priority >= other.priority

    def __gt__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__gt__(self, other)

        if self == other:
            return False

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id > other.id

            return self.size - self.received > other.size - other.received

        return self.priority <= other.priority

    def __ge__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__ge__(self, other)

        if self == other:
            return True

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id > other.id

            return self.size - self.received > other.size - other.received

        return self.priority <= other.priority

    def __eq__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__eq__(self, other)

        return self.id == other.id

    def on_availability_info_received(self, node_id, obj_id, info):
        if obj_id != self.id or self._finished or not info:
            return

        logger.info(
            "availability info received, "
            "node_id: %s, obj_id: %s, info: %s", node_id, obj_id, info)

        new_chunks_stored = self._store_availability_info(node_id, info)
        if not self._ready and new_chunks_stored:
            if self._check_can_receive(node_id):
                self._ready = True
                self.download_ready.emit(self)
            else:
                self.download_error.emit('Turn limit reached')

        if self._started and not self._paused \
                and not self._nodes_requested_chunks.get(node_id, None):
            logger.debug("Downloading next chunk")
            self._download_next_chunks(node_id)
            self._clean_nodes_last_receive_time()
            self._check_download_not_ready(self._nodes_requested_chunks)

    def on_availability_info_failure(self, node_id, obj_id, error):
        if obj_id != self.id or self._finished:
            return

        logger.info(
            "availability info failure, "
            "node_id: %s, obj_id: %s, error: %s", node_id, obj_id, error)
        try:
            if error["err_code"] == "FILE_CHANGED":
                self.download_failed.emit(self)
        except Exception as e:
            logger.warning("Can't parse error message. Reson: %s", e)

    def start(self, limiter):
        if exists(self.file_path):
            logger.info("download task file already downloaded %s",
                        self.file_path)
            self.received = self.size
            self.download_finishing.emit()
            self.download_complete.emit(self)
            return

        self._limiter = limiter

        if self._started:
            # if we swapped task earlier
            self.resume()
            return

        self._no_disk_space_error = False
        if not self.check_disk_space():
            return

        logger.info("starting download task, obj_id: %s", self.id)
        self._started = True
        self._paused = False
        self.hash_is_wrong = False
        self._started_time = time()
        self._send_start_statistic()
        if not self._open_file():
            return

        self._read_info_file()

        for downloaded_chunk in self._downloaded_chunks.items():
            self._remove_from_chunks(downloaded_chunk[0], downloaded_chunk[1],
                                     self._wanted_chunks)

        self.received = sum(self._downloaded_chunks.values())
        if self._complete_download():
            return

        self._download_chunks()
        if not self._timeout_timer.isActive():
            self._timeout_timer.start()

    def check_disk_space(self):
        if self.size * 2 + get_signature_file_size(self.size) > \
                get_free_space_by_filepath(self.file_path):
            self._emit_no_disk_space()
            return False

        return True

    def pause(self, disconnect_cb=True):
        self._paused = True
        if disconnect_cb:
            self.disconnect_callbacks()
        self.stop_download_chunks()

    def resume(self, start_download=True):
        self._started_time = time()
        self._paused = False
        self.hash_is_wrong = False
        if start_download:
            self._started = True
            self._download_chunks()
            if not self._timeout_timer.isActive():
                self._timeout_timer.start()

    def cancel(self):
        self._close_file()
        self._close_info_file()
        self.stop_download_chunks()

        self._finished = True

    def clean(self):
        logger.debug("Cleaning download files %s", self.download_path)
        try:
            remove_file(self.download_path)
        except Exception:
            pass
        try:
            remove_file(self._info_path)
        except Exception:
            pass

    def connect_callbacks(self, on_downloaded, on_failed):
        self._on_downloaded_cb = on_downloaded
        self._on_failed_cb = on_failed

    def disconnect_callbacks(self):
        self._on_downloaded_cb = None
        self._on_failed_cb = None

    @property
    def ready(self):
        return self._ready

    @property
    def paused(self):
        return self._paused

    @property
    def no_disk_space_error(self):
        return self._no_disk_space_error

    def _init_wanted_chunks(self):
        self._total_chunks_count = math.ceil(
            float(self.size) / float(DOWNLOAD_CHUNK_SIZE))

        self._wanted_chunks[0] = self.size

    def _on_downloaded(self, task):
        if callable(self._on_downloaded_cb):
            self._on_downloaded_cb(task)
            self._on_downloaded_cb = None

    def _on_failed(self, task):
        if callable(self._on_failed_cb):
            self._on_failed_cb(task)
            self._on_failed_cb = None

    def on_data_received(self, node_id, obj_id, offset, length, data):
        if obj_id != self.id or self._finished:
            return

        logger.debug(
            "on_data_received for objId: %s, offset: %s, from node_id: %s",
            self.id, offset, node_id)

        now = time()
        last_received_time = self._nodes_last_receive_time.get(node_id, 0.)
        if node_id in self._nodes_last_receive_time:
            self._nodes_last_receive_time[node_id] = now

        self._nodes_timeouts_count.pop(node_id, 0)

        downloaded_count = \
            self._nodes_downloaded_chunks_count.get(node_id, 0) + 1
        self._nodes_downloaded_chunks_count[node_id] = downloaded_count

        # to collect traffic info
        node_type = self._connectivity_service.get_self_node_type()
        is_share = node_type == "webshare"
        # tuple -> (obj_id, rx_wd, rx_wr, is_share)
        if self._connectivity_service.is_relayed(node_id):
            # relayed traffic
            info_rx = (obj_id, 0, length, is_share)
        else:
            # p2p traffic
            info_rx = (obj_id, length, 0, is_share)
        self.signal_info_rx.emit(info_rx)

        if not self._is_chunk_already_downloaded(offset):
            if not self._on_new_chunk_downloaded(node_id, offset, length,
                                                 data):
                return

        else:
            logger.debug("chunk %s already downloaded", offset)

        requested_chunks = self._nodes_requested_chunks.get(
            node_id, SortedDict())
        if not requested_chunks:
            return

        self._remove_from_chunks(offset, length, requested_chunks)

        if not requested_chunks:
            self._nodes_requested_chunks.pop(node_id, None)

        requested_count = sum(requested_chunks.values()) // DOWNLOAD_CHUNK_SIZE
        if downloaded_count * 4 >= requested_count \
                and requested_count < self.max_node_chunk_requests:
            self._download_next_chunks(node_id, now - last_received_time)
            self._clean_nodes_last_receive_time()
            self._check_download_not_ready(self._nodes_requested_chunks)

    def _is_chunk_already_downloaded(self, offset):
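        # Chunks are stored as SortedDict{offset: length}. The only chunk
        # that can contain `offset` is the one with the greatest start
        # <= offset, i.e. the item just left of bisect_right(offset);
        # `offset` is covered iff it falls before that chunk's end
        # (start + length).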
        if self._downloaded_chunks:
            chunk_index = self._downloaded_chunks.bisect_right(offset)
            if chunk_index > 0:
                chunk_index -= 1

                chunk = self._downloaded_chunks.peekitem(chunk_index)
                if offset < chunk[0] + chunk[1]:
                    return True

        return False

    def _on_new_chunk_downloaded(self, node_id, offset, length, data):
        if not self._write_to_file(offset, data):
            return False

        self.received += length
        if self._connectivity_service.is_relayed(node_id):
            self._received_via_turn += length
        else:
            self._received_via_p2p += length

        new_offset = offset
        new_length = length

        left_index = self._downloaded_chunks.bisect_right(new_offset)
        if left_index > 0:
            left_chunk = self._downloaded_chunks.peekitem(left_index - 1)
            if left_chunk[0] + left_chunk[1] == new_offset:
                new_offset = left_chunk[0]
                new_length += left_chunk[1]
                self._downloaded_chunks.popitem(left_index - 1)

        right_index = self._downloaded_chunks.bisect_right(new_offset +
                                                           new_length)
        if right_index > 0:
            right_chunk = self._downloaded_chunks.peekitem(right_index - 1)
            if right_chunk[0] == new_offset + new_length:
                new_length += right_chunk[1]
                self._downloaded_chunks.popitem(right_index - 1)

        self._downloaded_chunks[new_offset] = new_length

        assert self._remove_from_chunks(offset, length, self._wanted_chunks)

        logger.debug("new chunk downloaded from node: %s, wanted size: %s",
                     node_id, sum(self._wanted_chunks.values()))

        part_offset = (offset // DOWNLOAD_PART_SIZE) * DOWNLOAD_PART_SIZE
        part_size = min(DOWNLOAD_PART_SIZE, self.size - part_offset)
        if new_offset <= part_offset \
                and new_offset + new_length >= part_offset + part_size:
            if self._file:
                self._file.flush()
            self._write_info_file()

            self.chunk_downloaded.emit(self.id, str(part_offset), part_size)

        if self._complete_download():
            return False

        return True

    def _remove_from_chunks(self, offset, length, chunks):
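        # Remove the range [offset, offset + length) from `chunks`, a
        # SortedDict of {offset: length}: chunks wholly inside the range are
        # deleted, a chunk sticking out on the left is re-inserted truncated,
        # and any remainder sticking out on the right is re-inserted after
        # the range. Returns True iff the range overlapped anything.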
        if not chunks:
            return False

        chunk_left_index = chunks.bisect_right(offset)
        if chunk_left_index > 0:
            left_chunk = chunks.peekitem(chunk_left_index - 1)
            if offset >= left_chunk[0] + left_chunk[1] \
                    and len(chunks) > chunk_left_index:
                left_chunk = chunks.peekitem(chunk_left_index)
            else:
                chunk_left_index -= 1
        else:
            left_chunk = chunks.peekitem(chunk_left_index)

        if offset >= left_chunk[0] + left_chunk[1] or \
                offset + length <= left_chunk[0]:
            return False

        chunk_right_index = chunks.bisect_right(offset + length)
        right_chunk = chunks.peekitem(chunk_right_index - 1)

        if chunk_right_index == chunk_left_index:
            to_del = [right_chunk[0]]
        else:
            to_del = list(chunks.islice(chunk_left_index, chunk_right_index))

        for chunk in to_del:
            chunks.pop(chunk)

        if left_chunk[0] < offset:
            if left_chunk[0] + left_chunk[1] >= offset:
                chunks[left_chunk[0]] = offset - left_chunk[0]

        if right_chunk[0] + right_chunk[1] > offset + length:
            chunks[offset + length] = \
                right_chunk[0] + right_chunk[1] - offset - length
        return True

    def on_data_failed(self, node_id, obj_id, offset, error):
        if obj_id != self.id or self._finished:
            return

        logger.info(
            "data request failure, "
            "node_id: %s, obj_id: %s, offset: %s, error: %s", node_id, obj_id,
            offset, error)

        self.on_node_disconnected(node_id)

    def get_downloaded_chunks(self):
        if not self._downloaded_chunks:
            return None

        return self._downloaded_chunks

    def on_node_disconnected(self,
                             node_id,
                             connection_alive=False,
                             timeout_limit_exceed=True):
        requested_chunks = self._nodes_requested_chunks.pop(node_id, None)
        logger.info("node disconnected %s, chunks removed from requested: %s",
                    node_id, requested_chunks)
        if timeout_limit_exceed:
            self._nodes_available_chunks.pop(node_id, None)
            self._nodes_timeouts_count.pop(node_id, None)
            if connection_alive:
                self._connectivity_service.reconnect(node_id)
        self._nodes_last_receive_time.pop(node_id, None)
        self._nodes_downloaded_chunks_count.pop(node_id, None)

        if connection_alive:
            self.abort_data.emit(node_id, self.id, None)

        if self._nodes_available_chunks:
            self._download_chunks(check_node_busy=True)
        else:
            chunks_to_test = self._nodes_requested_chunks \
                if self._started and not self._paused \
                else self._nodes_available_chunks
            self._check_download_not_ready(chunks_to_test)

    def complete(self):
        if self._started and not self._finished:
            self._complete_download(force_complete=True)
        elif not self._finished:
            self._finished = True
            self.clean()
            self.download_complete.emit(self)

    def _download_chunks(self, check_node_busy=False):
        if not self._started or self._paused or self._finished:
            return

        logger.debug("download_chunks for %s", self.id)

        node_ids = list(self._nodes_available_chunks.keys())
        random.shuffle(node_ids)
        for node_id in node_ids:
            node_free = not check_node_busy or \
                        not self._nodes_requested_chunks.get(node_id, None)
            if node_free:
                self._download_next_chunks(node_id)
        self._clean_nodes_last_receive_time()
        self._check_download_not_ready(self._nodes_requested_chunks)

    def _check_can_receive(self, node_id):
        return True

    def _write_to_file(self, offset, data):
        self._file.seek(offset)
        try:
            self._file.write(data)
        except EnvironmentError as e:
            logger.error("Download task %s can't write to file. Reason: %s",
                         self.id, e)
            self._send_error_statistic()
            if e.errno == errno.ENOSPC:
                self._emit_no_disk_space(error=True)
            else:
                self.download_failed.emit(self)
                self.possibly_sync_folder_is_removed.emit()
            return False

        return True

    def _open_file(self, clean=False):
        if not self._file or self._file.closed:
            try:
                if clean:
                    self._file = open(self.download_path, 'wb')
                else:
                    self._file = open(self.download_path, 'r+b')
            except IOError:
                try:
                    self._file = open(self.download_path, 'wb')
                except IOError as e:
                    logger.error(
                        "Can't open file for download for task %s. "
                        "Reason: %s", self.id, e)
                    self.download_failed.emit(self)
                    return False

        return True

    def _close_file(self):
        if not self._file:
            return True

        try:
            self._file.close()
        except EnvironmentError as e:
            logger.error("Download task %s can't close file. Reason: %s",
                         self.id, e)
            self._send_error_statistic()
            if e.errno == errno.ENOSPC:
                self._emit_no_disk_space(error=True)
            else:
                self.download_failed.emit(self)
                self.possibly_sync_folder_is_removed.emit()
            self._file = None
            return False

        self._file = None
        return True

    def _write_info_file(self):
        try:
            self._info_file.seek(0)
            self._info_file.truncate()
            pickle.dump(self._downloaded_chunks, self._info_file,
                        pickle.HIGHEST_PROTOCOL)
            self._info_file.flush()
        except EnvironmentError as e:
            logger.debug("Can't write to info file for task id %s. Reason: %s",
                         self.id, e)

    def _read_info_file(self):
        try:
            if not self._info_file or self._info_file.closed:
                self._info_file = open(self._info_path, 'a+b')
                self._info_file.seek(0)
            try:
                self._downloaded_chunks = pickle.load(self._info_file)
            except Exception:
                pass
        except EnvironmentError as e:
            logger.debug("Can't open info file for task id %s. Reason: %s",
                         self.id, e)

    def _close_info_file(self, to_remove=False):
        if not self._info_file:
            return

        try:
            self._info_file.close()
            if to_remove:
                remove_file(self._info_path)
        except Exception as e:
            logger.debug(
                "Can't close or remove info file "
                "for task id %s. Reason: %s", self.id, e)
        self._info_file = None

    def _complete_download(self, force_complete=False):
        if (not self._wanted_chunks or force_complete) and \
                not self._finished:
            logger.debug("download %s completed", self.id)
            self._nodes_requested_chunks.clear()
            for node_id in self._nodes_last_receive_time.keys():
                self.abort_data.emit(node_id, self.id, None)

            if not force_complete:
                self.download_finishing.emit()

            if not force_complete and self.file_hash:
                hash_check_result = self._check_file_hash()
                if hash_check_result is not None:
                    return hash_check_result

            self._started = False
            self._finished = True
            self.stop_download_chunks()
            self._close_info_file(to_remove=True)
            if not self._close_file():
                return False

            try:
                if force_complete:
                    remove_file(self.download_path)
                    self.download_complete.emit(self)
                else:
                    shutil.move(self.download_path, self.file_path)
                    self._send_end_statistic()
                    self.download_complete.emit(self)
                    if self.file_hash:
                        self.copy_added.emit(self.file_hash)
            except EnvironmentError as e:
                logger.error(
                    "Download task %s can't (re)move file. "
                    "Reason: %s", self.id, e)
                self._send_error_statistic()
                self.download_failed.emit(self)
                self.possibly_sync_folder_is_removed.emit()
                return False

            result = True
        else:
            result = not self._wanted_chunks
        return result

    def _check_file_hash(self):
        self._file.flush()
        try:
            hash = Rsync.hash_from_block_checksum(
                Rsync.block_checksum(self.download_path))
        except IOError as e:
            logger.error("download %s error: %s", self.id, e)
            hash = None
        if hash != self.file_hash:
            logger.error(
                "download hash check failed objId: %s, "
                "expected hash: %s, actual hash: %s", self.id, self.file_hash,
                hash)
            if not self._close_file() or not self._open_file(clean=True):
                return False

            self._downloaded_chunks.clear()
            self._nodes_downloaded_chunks_count.clear()
            self._nodes_last_receive_time.clear()
            self._nodes_timeouts_count.clear()
            self._write_info_file()
            self._init_wanted_chunks()

            self.received = 0
            if self._retry < self.retry_limit:
                self._retry += 1
                self.resume()
            else:
                self._retry = 0
                self._nodes_available_chunks.clear()
                self.hash_is_wrong = True
                self.wrong_hash.emit(self)
            return True

        return None

    def _download_next_chunks(self, node_id, time_from_last_received_chunk=0.):
        if (self._paused or not self._started or not self._ready
                or self._finished or not self._wanted_chunks
                or self._leaky_timer.isActive()):
            return

        total_requested = sum(
            map(lambda x: sum(x.values()),
                self._nodes_requested_chunks.values()))

        if total_requested + self.received >= self.size:
            if self._nodes_requested_chunks.get(node_id, None) and \
                    time_from_last_received_chunk <= self.end_race_timeout:
                return

            available_chunks = \
                self._get_end_race_chunks_to_download_from_node(node_id)
        else:
            available_chunks = \
                self._get_available_chunks_to_download_from_node(node_id)

        if not available_chunks:
            logger.debug("no chunks available for download %s", self.id)
            logger.debug("downloading from: %s nodes, length: %s, wanted: %s",
                         len(self._nodes_requested_chunks), total_requested,
                         self.size - self.received)
            return

        available_offset = random.sample(available_chunks.keys(), 1)[0]
        available_length = available_chunks[available_offset]
        logger.debug("selected random offset: %s", available_offset)

        parts_count = math.ceil(
            float(available_length) / float(DOWNLOAD_PART_SIZE)) - 1
        logger.debug("parts count: %s", parts_count)

        part_to_download_number = random.randint(0, parts_count)
        offset = available_offset + \
                 part_to_download_number * DOWNLOAD_PART_SIZE
        length = min(DOWNLOAD_PART_SIZE,
                     available_offset + available_length - offset)
        logger.debug("selected random part: %s, offset: %s, length: %s",
                     part_to_download_number, offset, length)

        self._request_data(node_id, offset, length)

    def _get_end_race_chunks_to_download_from_node(self, node_id):
        available_chunks = self._nodes_available_chunks.get(node_id, None)
        if not available_chunks:
            return []

        available_chunks = available_chunks.copy()
        logger.debug("end race downloaded_chunks: %s", self._downloaded_chunks)
        logger.debug("end race requested_chunks: %s",
                     self._nodes_requested_chunks)
        logger.debug("end race available_chunks before excludes: %s",
                     available_chunks)
        if self._downloaded_chunks:
            for downloaded_chunk in self._downloaded_chunks.items():
                self._remove_from_chunks(downloaded_chunk[0],
                                         downloaded_chunk[1], available_chunks)
        if not available_chunks:
            return []

        available_from_other_nodes = available_chunks.copy()
        for requested_offset, requested_length in \
                self._nodes_requested_chunks.get(node_id, dict()).items():
            self._remove_from_chunks(requested_offset, requested_length,
                                     available_from_other_nodes)

        result = available_from_other_nodes if available_from_other_nodes \
            else available_chunks

        if result:
            logger.debug("end race available_chunks after excludes: %s",
                         available_chunks)
        return result

    def _get_available_chunks_to_download_from_node(self, node_id):
        available_chunks = self._nodes_available_chunks.get(node_id, None)
        if not available_chunks:
            return []

        available_chunks = available_chunks.copy()
        logger.debug("downloaded_chunks: %s", self._downloaded_chunks)
        logger.debug("requested_chunks: %s", self._nodes_requested_chunks)
        logger.debug("available_chunks before excludes: %s", available_chunks)
        for _, requested_chunks in self._nodes_requested_chunks.items():
            for requested_offset, requested_length in requested_chunks.items():
                self._remove_from_chunks(requested_offset, requested_length,
                                         available_chunks)
        if not available_chunks:
            return []

        for downloaded_chunk in self._downloaded_chunks.items():
            self._remove_from_chunks(downloaded_chunk[0], downloaded_chunk[1],
                                     available_chunks)
        logger.debug("available_chunks after excludes: %s", available_chunks)
        return available_chunks

    def _request_data(self, node_id, offset, length):
        logger.debug("Requesting date from node %s, request_chunk (%s, %s)",
                     node_id, offset, length)
        if self._limiter:
            try:
                self._limiter.leak(length)
            except LeakyBucketException:
                if node_id not in self._nodes_requested_chunks:
                    self._nodes_last_receive_time.pop(node_id, None)
                    if not self._network_limited_error_set:
                        self.download_error.emit('Network limited.')
                        self._network_limited_error_set = True
                if not self._leaky_timer.isActive():
                    self._leaky_timer.start()
                return

        if self._network_limited_error_set:
            self._network_limited_error_set = False
            self.download_ok.emit()

        requested_chunks = self._nodes_requested_chunks.get(node_id, None)
        if not requested_chunks:
            requested_chunks = SortedDict()
            self._nodes_requested_chunks[node_id] = requested_chunks
        requested_chunks[offset] = length
        logger.debug("Requested chunks %s", requested_chunks)
        self._nodes_last_receive_time[node_id] = time()
        self.request_data.emit(node_id, self.id, str(offset), length)

    def _clean_nodes_last_receive_time(self):
        for node_id in list(self._nodes_last_receive_time.keys()):
            if node_id not in self._nodes_requested_chunks:
                self._nodes_last_receive_time.pop(node_id, None)

    def _on_check_timeouts(self):
        if self._paused or not self._started \
                or self._finished or self._leaky_timer.isActive():
            return

        timed_out_nodes = set()
        cur_time = time()
        logger.debug("Chunk requests check %s",
                     len(self._nodes_requested_chunks))
        if self._check_download_not_ready(self._nodes_requested_chunks):
            return

        for node_id in self._nodes_last_receive_time:
            last_receive_time = self._nodes_last_receive_time.get(node_id)
            if cur_time - last_receive_time > self.receive_timeout:
                timed_out_nodes.add(node_id)

        logger.debug("Timed out nodes %s, nodes last receive time %s",
                     timed_out_nodes, self._nodes_last_receive_time)
        for node_id in timed_out_nodes:
            timeout_count = self._nodes_timeouts_count.pop(node_id, 0)
            timeout_count += 1
            if timeout_count >= self.timeouts_limit:
                retry = False
            else:
                retry = True
                self._nodes_timeouts_count[node_id] = timeout_count
            logger.debug("Node if %s, timeout_count %s, retry %s", node_id,
                         timeout_count, retry)
            self.on_node_disconnected(node_id,
                                      connection_alive=True,
                                      timeout_limit_exceed=not retry)

    def _get_chunks_from_info(self, chunks, info):
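        # Union each (offset, length) part from `info` into `chunks`
        # (SortedDict{offset: length}): overlapping or touching stored chunks
        # are deleted and replaced by a single coalesced entry. Returns True
        # iff at least one part added new coverage.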
        new_added = False
        for part_info in info:
            logger.debug("get_chunks_from_info part_info %s", part_info)
            if part_info.length == 0:
                continue

            if not chunks:
                chunks[part_info.offset] = part_info.length
                new_added = True
                continue

            result_offset = part_info.offset
            result_length = part_info.length
            left_index = chunks.bisect_right(part_info.offset)
            if left_index > 0:
                left_chunk = chunks.peekitem(left_index - 1)
                if (left_chunk[0] <= part_info.offset
                        and left_chunk[0] + left_chunk[1] >=
                        part_info.offset + part_info.length):
                    continue

                if part_info.offset <= left_chunk[0] + left_chunk[1]:
                    result_offset = left_chunk[0]
                    result_length = part_info.offset + \
                                    part_info.length - result_offset
                    left_index -= 1

            right_index = chunks.bisect_right(part_info.offset +
                                              part_info.length)
            if right_index > 0:
                right_chunk = chunks.peekitem(right_index - 1)
                if part_info.offset + part_info.length <= \
                        right_chunk[0] + right_chunk[1]:
                    result_length = right_chunk[0] + \
                                    right_chunk[1] - result_offset

            to_delete = list(chunks.islice(left_index, right_index))

            for to_del in to_delete:
                chunks.pop(to_del)

            new_added = True
            chunks[result_offset] = result_length

        return new_added

    def _store_availability_info(self, node_id, info):
        known_chunks = self._nodes_available_chunks.get(node_id, None)
        if not known_chunks:
            known_chunks = SortedDict()
            self._nodes_available_chunks[node_id] = known_chunks
        return self._get_chunks_from_info(known_chunks, info)

    def _check_download_not_ready(self, checkable):
        if not self._wanted_chunks and self._started:
            self._complete_download(force_complete=False)
            return False

        if self._leaky_timer.isActive():
            if not self._nodes_available_chunks:
                self._make_not_ready()
                return True

        elif not checkable:
            self._make_not_ready()
            return True

        return False

    def _make_not_ready(self):
        if not self._ready:
            return

        logger.info("download %s not ready now", self.id)
        self._ready = False
        self._started = False
        if self._timeout_timer.isActive():
            self._timeout_timer.stop()
        if self._leaky_timer.isActive():
            self._leaky_timer.stop()
        self.download_not_ready.emit(self)

    def _clear_globals(self):
        self._wanted_chunks.clear()
        self._downloaded_chunks.clear()
        self._nodes_available_chunks.clear()
        self._nodes_requested_chunks.clear()
        self._nodes_last_receive_time.clear()
        self._nodes_downloaded_chunks_count.clear()
        self._nodes_timeouts_count.clear()
        self._total_chunks_count = 0

    def stop_download_chunks(self):
        if self._leaky_timer.isActive():
            self._leaky_timer.stop()
        if self._timeout_timer.isActive():
            self._timeout_timer.stop()

        for node_id in self._nodes_requested_chunks:
            self.abort_data.emit(node_id, self.id, None)

        self._nodes_requested_chunks.clear()
        self._nodes_last_receive_time.clear()

    def _emit_no_disk_space(self, error=False):
        self._no_disk_space_error = True
        self._nodes_available_chunks.clear()
        self._clear_globals()
        self._make_not_ready()
        file_name = self.display_name.split()[-1] \
            if self.display_name else ""
        self.no_disk_space.emit(self, file_name, error)

    def _send_start_statistic(self):
        if self._tracker:
            self._tracker.download_start(self.id, self.size)

    def _send_end_statistic(self):
        if self._tracker:
            time_diff = time() - self._started_time
            if time_diff < 1e-3:
                time_diff = 1e-3

            self._tracker.download_end(
                self.id,
                time_diff,
                websockets_bytes=0,
                webrtc_direct_bytes=self._received_via_p2p,
                webrtc_relay_bytes=self._received_via_turn,
                chunks=len(self._downloaded_chunks),
                chunks_reloaded=0,
                nodes=len(self._nodes_available_chunks))

    def _send_error_statistic(self):
        if self._tracker:
            time_diff = time() - self._started_time
            if time_diff < 1e-3:
                time_diff = 1e-3

            self._tracker.download_error(
                self.id,
                time_diff,
                websockets_bytes=0,
                webrtc_direct_bytes=self._received_via_p2p,
                webrtc_relay_bytes=self._received_via_turn,
                chunks=len(self._downloaded_chunks),
                chunks_reloaded=0,
                nodes=len(self._nodes_available_chunks))
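
A minimal, self-contained sketch (hypothetical helper, not part of DownloadTask) of the bookkeeping pattern _on_new_chunk_downloaded uses: downloaded ranges live in a SortedDict of {offset: length}, and a newly received chunk is coalesced with its left and right neighbours via bisect_right/peekitem.

from sortedcontainers import SortedDict

def add_chunk(chunks, offset, length):
    # merge with a chunk that ends exactly at `offset`, if any
    i = chunks.bisect_right(offset)
    if i > 0:
        left_off, left_len = chunks.peekitem(i - 1)
        if left_off + left_len == offset:
            offset, length = left_off, left_len + length
            chunks.popitem(i - 1)
    # merge with a chunk that starts exactly at the new end, if any
    j = chunks.bisect_right(offset + length)
    if j > 0:
        right_off, right_len = chunks.peekitem(j - 1)
        if right_off == offset + length:
            length += right_len
            chunks.popitem(j - 1)
    chunks[offset] = length

chunks = SortedDict()
add_chunk(chunks, 0, 4)
add_chunk(chunks, 8, 4)
add_chunk(chunks, 4, 4)  # bridges the gap
print(chunks)            # SortedDict({0: 12})
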
Example #25
class FederationRemoteSendQueue(AbstractFederationSender):
    """A drop in replacement for FederationSender"""
    def __init__(self, hs: "HomeServer"):
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        # We may have multiple federation sender instances, so we need to track
        # their positions separately.
        self._sender_instances = hs.config.worker.federation_shard_config.instances
        self._sender_positions = {}  # type: Dict[str, int]

        # Pending presence map user_id -> UserPresenceState
        self.presence_map = {}  # type: Dict[str, UserPresenceState]

        # Stores the destinations we need to explicitly send presence to about a
        # given user.
        # Stream position -> (user_id, destinations)
        self.presence_destinations = SortedDict()  # type: SortedDict[int, Tuple[str, Iterable[str]]]

        # (destination, key) -> EDU
        self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]

        # stream position -> (destination, key)
        self.keyed_edu_changed = SortedDict()  # type: SortedDict[int, Tuple[str, tuple]]

        self.edus = SortedDict()  # type: SortedDict[int, Edu]

        # stream ID for the next entry into keyed_edu_changed/edus.
        self.pos = 1

        # map from stream ID to the time that stream entry was generated, so that we
        # can clear out entries after a while
        self.pos_time = SortedDict()  # type: SortedDict[int, int]

        # EVERYTHING IS SAD. In particular, python only makes new scopes when
        # we make a new function, so we need to make a new function so the inner
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name: str, queue: Sized) -> None:
            LaterGauge(
                "synapse_federation_send_queue_%s_size" % (queue_name, ),
                "",
                [],
                lambda: len(queue),
            )

        for queue_name in [
                "presence_map",
                "keyed_edu",
                "keyed_edu_changed",
                "edus",
                "pos_time",
                "presence_destinations",
        ]:
            register(queue_name, getattr(self, queue_name))

        self.clock.looping_call(self._clear_queue, 30 * 1000)

    def _next_pos(self) -> int:
        pos = self.pos
        self.pos += 1
        self.pos_time[self.clock.time_msec()] = pos
        return pos

    def _clear_queue(self) -> None:
        """Clear the queues for anything older than N minutes"""

        FIVE_MINUTES_AGO = 5 * 60 * 1000
        now = self.clock.time_msec()

        keys = self.pos_time.keys()
        time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
        if not keys[:time]:
            return

        position_to_delete = max(keys[:time])
        for key in keys[:time]:
            del self.pos_time[key]

        self._clear_queue_before_pos(position_to_delete)

    def _clear_queue_before_pos(self, position_to_delete: int) -> None:
        """Clear all the queues from before a given position"""
        with Measure(self.clock, "send_queue._clear"):
            # Delete things out of presence maps
            keys = self.presence_destinations.keys()
            i = self.presence_destinations.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_destinations[key]

            user_ids = {
                user_id
                for user_id, _ in self.presence_destinations.values()
            }

            to_del = [
                user_id for user_id in self.presence_map
                if user_id not in user_ids
            ]
            for user_id in to_del:
                del self.presence_map[user_id]

            # Delete things out of keyed edus
            keys = self.keyed_edu_changed.keys()
            i = self.keyed_edu_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.keyed_edu_changed[key]

            live_keys = set()
            for edu_key in self.keyed_edu_changed.values():
                live_keys.add(edu_key)

            keys_to_del = [
                edu_key for edu_key in self.keyed_edu
                if edu_key not in live_keys
            ]
            for edu_key in keys_to_del:
                del self.keyed_edu[edu_key]

            # Delete things out of edu map
            keys = self.edus.keys()
            i = self.edus.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.edus[key]

    def notify_new_events(self, max_token: RoomStreamToken) -> None:
        """As per FederationSender"""
        # This should never get called.
        raise NotImplementedError()

    def build_and_send_edu(
        self,
        destination: str,
        edu_type: str,
        content: JsonDict,
        key: Optional[Hashable] = None,
    ) -> None:
        """As per FederationSender"""
        if destination == self.server_name:
            logger.info("Not sending EDU to ourselves")
            return

        pos = self._next_pos()

        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        if key:
            assert isinstance(key, tuple)
            self.keyed_edu[(destination, key)] = edu
            self.keyed_edu_changed[pos] = (destination, key)
        else:
            self.edus[pos] = edu

        self.notifier.on_new_replication_data()

    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
        """As per FederationSender

        Args:
            receipt:
        """
        # nothing to do here: the replication listener will handle it.

    def send_presence_to_destinations(self,
                                      states: Iterable[UserPresenceState],
                                      destinations: Iterable[str]) -> None:
        """As per FederationSender

        Args:
            states
            destinations
        """
        for state in states:
            pos = self._next_pos()
            self.presence_map[state.user_id] = state
            self.presence_destinations[pos] = (state.user_id, destinations)

        self.notifier.on_new_replication_data()

    def send_device_messages(self, destination: str) -> None:
        """As per FederationSender"""
        # We don't need to replicate this as it gets sent down a different
        # stream.

    def wake_destination(self, server: str) -> None:
        pass

    def get_current_token(self) -> int:
        return self.pos - 1

    def federation_ack(self, instance_name: str, token: int) -> None:
        if self._sender_instances:
            # If we have configured multiple federation sender instances we need
            # to track their positions separately, and only clear the queue up
            # to the token all instances have acked.
            self._sender_positions[instance_name] = token
            token = min(self._sender_positions.values())

        self._clear_queue_before_pos(token)

    async def get_replication_rows(
            self, instance_name: str, from_token: int, to_token: int,
            target_row_count: int
    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
        """Get rows to be sent over federation between the two tokens

        Args:
            instance_name: the name of the current process
            from_token: the previous stream token: the starting point for fetching the
                updates
            to_token: the new stream token: the point to get updates up to
            target_row_count: a target for the number of rows to be returned.

        Returns: a triplet `(updates, new_last_token, limited)`, where:
           * `updates` is a list of `(token, row)` entries.
           * `new_last_token` is the new position in stream.
           * `limited` is whether there are more updates to fetch.
        """
        # TODO: Handle target_row_count.

        # To handle restarts where we wrap around
        if from_token > self.pos:
            from_token = -1

        # list of tuple(int, BaseFederationRow), where the first is the position
        # of the federation stream.
        rows = []  # type: List[Tuple[int, BaseFederationRow]]

        # Fetch presence to send to destinations
        i = self.presence_destinations.bisect_right(from_token)
        j = self.presence_destinations.bisect_right(to_token) + 1

        for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
            rows.append((
                pos,
                PresenceDestinationsRow(state=self.presence_map[user_id],
                                        destinations=list(dests)),
            ))

        # Fetch changes keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)
        j = self.keyed_edu_changed.bisect_right(to_token) + 1
        # We purposefully clobber based on the key here, python dict comprehensions
        # always use the last value, so this will correctly point to the last
        # stream position.
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in keyed_edus.items():
            rows.append((
                pos,
                KeyedEduRow(key=edu_key,
                            edu=self.keyed_edu[(destination, edu_key)]),
            ))

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)
        j = self.edus.bisect_right(to_token) + 1
        edus = self.edus.items()[i:j]

        for (pos, edu) in edus:
            rows.append((pos, EduRow(edu)))

        # Sort rows based on pos
        rows.sort()

        return (
            [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
            to_token,
            False,
        )
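
A minimal sketch (made-up timestamps) of the pruning pattern _clear_queue applies to pos_time: because the SortedDict is keyed by insertion time, bisect_left on a cutoff splits the keys into a stale prefix and a live suffix, and the stale prefix can be deleted wholesale.

from sortedcontainers import SortedDict

pos_time = SortedDict({100: 1, 200: 2, 300: 3, 400: 4})  # time_msec -> stream pos
cutoff = 250  # drop everything inserted before this time

i = pos_time.bisect_left(cutoff)
for key in list(pos_time.keys()[:i]):  # [100, 200]
    del pos_time[key]
print(pos_time)  # SortedDict({300: 3, 400: 4})
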
Example #26
class StepVector():
    @classmethod
    def sliced(cls, other, start, end):
        newobj = cls(other.datatype, _tree=other._t, _bounds=(start, end))
        return newobj

    def __init__(self, datatype, _tree=None, _bounds=None):
        self.datatype = datatype

        if _tree is not None:
            self._t = _tree
        else:
            self._t = SortedDict()

        if _bounds is not None:
            self._bounds = _bounds
        else:
            self._bounds = (None, None)  # set upon slicing/subsetting

    def __getitem__(self, key):
        if type(key) == slice:
            if (key.step is not None) and (key.step != 1):
                raise ValueError("Invalid step value")

            start = key.start
            end = key.stop

            if self._bounds[0] is not None:
                if start is None:
                    start = self._bounds[0]
                else:
                    if start < self._bounds[0]:
                        raise ValueError("Start out of bounds")
            if self._bounds[1] is not None:
                if end is None:
                    end = self._bounds[1]
                else:
                    if end > self._bounds[1]:
                        raise ValueError("End out of bounds")

            return self.sliced(self, start, end)
        else:
            assert type(key) == int

            if self._bounds[0] is not None:
                if key < self._bounds[0]:
                    raise ValueError("Key out of bounds")
            if self._bounds[1] is not None:
                if key >= self._bounds[1]:
                    raise ValueError("Key out of bounds")

            if self._t:
                try:
                    prevkey = self._floor_key(key)
                    return self._t[prevkey]
                except KeyError:
                    # no item smaller than or equal to key
                    return self.datatype()
            else:
                # empty tree
                return self.datatype()

    def __setitem__(self, key, value):
        if type(key) == slice:
            start = key.start
            end = key.stop
        else:
            assert type(key) == int
            start = key
            end = key + 1

        assert start is not None
        assert end is not None

        assert type(value) == self.datatype
        assert end >= start

        if start == end:
            return

        # check next val
        if self._t:
            try:
                nkey = self._floor_key(end, bisect="right")
                nvalue = self._t[nkey]
            except KeyError:
                nkey = None
                nvalue = None
        else:
            # empty tree
            nkey = None
            nvalue = None

        # check prev val
        if self._t:
            try:
                pkey = self._floor_key(start)
                pvalue = self._t[pkey]
            except KeyError:
                pkey = None
                pvalue = None
        else:
            pkey = None
            pvalue = None

        # remove intermediate steps if any
        if self._t:
            a = self._t.bisect_left(start)
            b = self._t.bisect_right(end)  # bisect() is an alias for bisect_right()
            assert a <= b
            # SortedDict.iloc is deprecated; delete keys in [a, b) via the keys view
            for k in list(self._t.keys()[a:b]):
                del self._t[k]

        # set an end marker if necessary
        if nkey is None:
            self._t[end] = self.datatype()
        elif nvalue != value:
            self._t[end] = nvalue

        # set a start marker if necessary
        if pkey is None or pvalue != value:
            self._t[start] = value

    def __iter__(self):
        start, end = self._bounds

        if not self._t:
            # empty tree
            if start is None or end is None:
                # PEP 479: raising StopIteration inside a generator is a
                # RuntimeError on Python 3.7+; a plain return ends iteration
                return  # FIXME: can't figure out a better thing to do if only one is None
            else:
                if start < end:
                    yield (start, end, self.datatype())
                return

        if start is None:
            a = 0
        else:
            a = max(0, self._bisect_right(start) - 1)

        if end is None:
            b = len(self._t)
        else:
            b = self._bisect_right(end)

        assert b >= a
        if a == b:
            if a is None:
                start = self._t[a]
            if b is None:
                end = self._t[b]

            if start < end:
                yield (start, end, self.datatype())

            return

        it = self._t.islice(a, b)

        currkey = next(it)
        currvalue = self._t[currkey]
        if start is not None:
            currkey = max(start, currkey)
            if start < currkey:
                yield (start, currkey, self.datatype())

        prevkey, prevvalue = currkey, currvalue
        for currkey in it:
            currvalue = self._t[currkey]
            yield (prevkey, currkey, prevvalue)
            prevkey = currkey
            prevvalue = currvalue

        if end is not None:
            if currkey < end:
                yield (currkey, end, prevvalue)

    def add_value(self, start, end, value):
        assert type(value) == self.datatype

        # can't modify self while iterating over values; will change the tree, and thus f**k up iteration
        items = list(self[start:end])

        for a, b, x in items:
            if self.datatype == set:
                y = x.copy()
                y.update(value)
            else:
                y = x + value

            self[a:b] = y

    def _bisect_left(self, key):
        return self._t.bisect_left(key)

    def _bisect_right(self, key):
        return self._t.bisect_right(key)

    def _floor_key(self, key, bisect="left"):
        """
        Returns the greatest key less than or equal to key
        """

        if bisect == "right":
            p = self._bisect_right(key)
        else:
            p = self._bisect_left(key)

        if p == 0:
            raise KeyError(key)
        else:
            # SortedDict.iloc is deprecated; index the keys view instead
            return self._t.keys()[p - 1]
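
A minimal sketch (hypothetical values) of the floor-key lookup that _floor_key implements and __getitem__ relies on: a step function stored as SortedDict{boundary: value} is evaluated at x by taking the value under the greatest key <= x, i.e. the item just left of bisect_right(x).

from sortedcontainers import SortedDict

steps = SortedDict({0: 'a', 5: 'b', 9: 'c'})

def value_at(x):
    i = steps.bisect_right(x)
    if i == 0:
        raise KeyError(x)  # x lies before the first step boundary
    return steps.peekitem(i - 1)[1]

print(value_at(3))   # 'a'
print(value_at(5))   # 'b'
print(value_at(42))  # 'c'
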
Example #27
class StreamChangeCache(object):
    """Keeps track of the stream positions of the latest change in a set of entities.

    Typically the entity will be a room or user id.

    Given a list of entities and a stream position, it will give a subset of
    entities that may have changed since that position. If position key is too
    old then the cache will simply return all given entities.
    """

    def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
        self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
        self._entity_to_key = {}
        self._cache = SortedDict()
        self._earliest_known_stream_pos = current_stream_pos
        self.name = name
        self.metrics = caches.register_cache("cache", self.name, self._cache)

        if prefilled_cache:
            for entity, stream_pos in prefilled_cache.items():
                self.entity_has_changed(entity, stream_pos)

    def has_entity_changed(self, entity, stream_pos):
        """Returns True if the entity may have been updated since stream_pos
        """
        assert type(stream_pos) in integer_types

        if stream_pos < self._earliest_known_stream_pos:
            self.metrics.inc_misses()
            return True

        latest_entity_change_pos = self._entity_to_key.get(entity, None)
        if latest_entity_change_pos is None:
            self.metrics.inc_hits()
            return False

        if stream_pos < latest_entity_change_pos:
            self.metrics.inc_misses()
            return True

        self.metrics.inc_hits()
        return False

    def get_entities_changed(self, entities, stream_pos):
        """
        Returns subset of entities that have had new things since the given
        position.  Entities unknown to the cache will be returned.  If the
        position is too old it will just return the given list.
        """
        assert type(stream_pos) is int

        if stream_pos >= self._earliest_known_stream_pos:
            changed_entities = {
                self._cache[k] for k in self._cache.islice(
                    start=self._cache.bisect_right(stream_pos),
                )
            }

            result = changed_entities.intersection(entities)

            self.metrics.inc_hits()
        else:
            result = set(entities)
            self.metrics.inc_misses()

        return result

    def has_any_entity_changed(self, stream_pos):
        """Returns if any entity has changed
        """
        assert type(stream_pos) is int

        if not self._cache:
            # If we have no cache, nothing can have changed.
            return False

        if stream_pos >= self._earliest_known_stream_pos:
            self.metrics.inc_hits()
            return self._cache.bisect_right(stream_pos) < len(self._cache)
        else:
            self.metrics.inc_misses()
            return True

    def get_all_entities_changed(self, stream_pos):
        """Returns all entites that have had new things since the given
        position. If the position is too old it will return None.
        """
        assert type(stream_pos) is int

        if stream_pos >= self._earliest_known_stream_pos:
            return [self._cache[k] for k in self._cache.islice(
                start=self._cache.bisect_right(stream_pos))]
        else:
            return None

    def entity_has_changed(self, entity, stream_pos):
        """Informs the cache that the entity has been changed at the given
        position.
        """
        assert type(stream_pos) is int

        if stream_pos > self._earliest_known_stream_pos:
            old_pos = self._entity_to_key.get(entity, None)
            if old_pos is not None:
                stream_pos = max(stream_pos, old_pos)
                self._cache.pop(old_pos, None)
            self._cache[stream_pos] = entity
            self._entity_to_key[entity] = stream_pos

            while len(self._cache) > self._max_size:
                k, r = self._cache.popitem(0)
                self._earliest_known_stream_pos = max(
                    k, self._earliest_known_stream_pos,
                )
                self._entity_to_key.pop(r, None)

    def get_max_pos_of_last_change(self, entity):
        """Returns an upper bound of the stream id of the last change to an
        entity.
        """
        return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
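
Stripped of the Synapse metrics and eviction plumbing, the core lookup behind get_entities_changed / get_all_entities_changed is just a bisect plus an index slice; a small sketch with made-up entity ids:

from sortedcontainers import SortedDict

cache = SortedDict({1: "@a:hs", 3: "@b:hs", 7: "@c:hs"})  # stream pos -> entity

stream_pos = 3
# Everything strictly after stream_pos, mirroring
# self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed = {cache[k] for k in cache.islice(start=cache.bisect_right(stream_pos))}
assert changed == {"@c:hs"}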
示例#28
0
class BaseColorCodePatchBuilder(ASAxesPatchBuilder, PickablePatchBuilder):
    """
    The patch generator builds the matplotlib patches for each
    capability node.

    The nodes are rendered as lines with a different color depending
    on the permission bits of the capability. The builder produces
    a LineCollection for each combination of permission bits and
    creates the lines for the nodes.
    """
    def __init__(self, figure, pgm):
        """
        Constructor

        :param figure: the figure to attach the click callback
        :param pgm: the provenance graph model
        """
        super().__init__(figure=figure)

        self._pgm = pgm
        """The provenance graph model"""

        self._collection_map = defaultdict(lambda: [])
        """
        Map capability permission to the set where the line should go.
        Any combination of capability permissions is used as key for
        a list of (start, end) values that are used to build LineCollections.
        The key "call" is used for system call nodes, the int(0) key is used
        for no permission.
        """

        self._colors = {}
        """
        Map capability permission to line colors.
        XXX: keep this for now, move to a colormap
        """

        self._bbox = [np.inf, np.inf, 0, 0]
        """Bounding box of the patches as (xmin, ymin, xmax, ymax)."""

        self._node_map = SortedDict()
        """Maps the Y axis coordinate to the graph node at that position"""

    def _clickable_element(self, vertex, y):
        """remember the node at the given Y for faster indexing."""
        data = self._pgm.data[vertex]
        self._node_map[y] = data

    def _add_bbox(self, xmin, xmax, y):
        """Update the view bbox."""
        if self._bbox[0] > xmin:
            self._bbox[0] = xmin
        if self._bbox[1] > y:
            self._bbox[1] = y
        if self._bbox[2] < xmax:
            self._bbox[2] = xmax
        if self._bbox[3] < y:
            self._bbox[3] = y

    def _get_patch_collections(self, axes):
        """Return a generator of collections of patches to add to the axes."""
        pass

    def get_patches(self, axes):
        """
        Return a collection of lines from the collection_map.
        """
        super().get_patches(axes)
        for coll in self._get_patch_collections(axes):
            axes.add_collection(coll)

    def get_bbox(self):
        return Bbox.from_extents(*self._bbox)

    def on_click(self, event):
        """
        Attempt to retrieve the data in less than O(n) for better
        interactivity at the expense of having to hold a dictionary of
        references to nodes for each t_alloc.
        Note that t_alloc is unique for each capability node as it
        is the cycle count, so it can be used as the key.
        """
        ax = event.inaxes
        if ax is None:
            return

        # back to data coords without scaling
        y_coord = int(event.ydata)
        y_max = self._bbox[3]
        # tolerance for y distance, 0.1 * 10^6 cycles
        epsilon = 0.1 * 10**6

        # try to get the node closest to the y_coord the fast way.
        # For now, fall back to a reduced linear search; it may be
        # worth indexing the lines with an R-tree eventually.
        idx_min = self._node_map.bisect_left(max(0, y_coord - epsilon))
        idx_max = self._node_map.bisect_right(min(y_max, y_coord + epsilon))
        iter_keys = self._node_map.islice(idx_min, idx_max)
        # find the closest node to the click position
        pick_target = None
        for key in iter_keys:
            node = self._node_map[key]
            if (node.cap.base <= event.xdata
                    and node.cap.bound >= event.xdata):
                # the click event is within the node bounds and
                # the node Y is closer to the click event than
                # the previous pick_target
                if (pick_target is None or abs(y_coord - key) <
                        abs(y_coord - pick_target.cap.t_alloc)):
                    pick_target = node
        if pick_target is not None:
            ax.set_status_message(pick_target)
        else:
            ax.set_status_message("")
def test_bisect_key():
    temp = SortedDict(modulo, 7, ((val, val) for val in range(100)))
    assert all(temp.bisect(val) == ((val % 10) + 1) * 10 for val in range(100))
    assert all(temp.bisect_right(val) == ((val % 10) + 1) * 10 for val in range(100))
    assert all(temp.bisect_left(val) == (val % 10) * 10 for val in range(100))
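
The test above presumably defines modulo as a key function like the one below (a reconstruction, not part of the snippet; the 7 in the constructor is a sortedcontainers v1 load factor that v2 dropped). With keys 0..99 ordered by val % 10, each residue class holds ten keys, which is exactly what the bisect assertions count:

from sortedcontainers import SortedDict

def modulo(val):
    return val % 10

temp = SortedDict(modulo, ((val, val) for val in range(100)))  # v2-style constructor
assert all(temp.bisect_right(val) == ((val % 10) + 1) * 10 for val in range(100))
assert all(temp.bisect_left(val) == (val % 10) * 10 for val in range(100))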
示例#30
0
class StreamChangeCache:
    """Keeps track of the stream positions of the latest change in a set of entities.

    Typically the entity will be a room or user id.

    Given a list of entities and a stream position, it will give a subset of
    entities that may have changed since that position. If position key is too
    old then the cache will simply return all given entities.
    """
    def __init__(
        self,
        name: str,
        current_stream_pos: int,
        max_size=10000,
        prefilled_cache: Optional[Mapping[EntityType, int]] = None,
    ):
        self._original_max_size = max_size
        self._max_size = math.floor(max_size)
        self._entity_to_key = {}  # type: Dict[EntityType, int]

        # map from stream id to the set of entities which changed at that stream id.
        self._cache = SortedDict()  # type: SortedDict[int, Set[EntityType]]

        # the earliest stream_pos for which we can reliably answer
        # get_all_entities_changed. In other words, one less than the earliest
        # stream_pos for which we know _cache is valid.
        #
        self._earliest_known_stream_pos = current_stream_pos
        self.name = name
        self.metrics = caches.register_cache(
            "cache",
            self.name,
            self._cache,
            resize_callback=self.set_cache_factor)

        if prefilled_cache:
            for entity, stream_pos in prefilled_cache.items():
                self.entity_has_changed(entity, stream_pos)

    def set_cache_factor(self, factor: float) -> bool:
        """
        Set the cache factor for this individual cache.

        This will trigger a resize if it changes, which may require evicting
        items from the cache.

        Returns:
            bool: Whether the cache changed size or not.
        """
        new_size = math.floor(self._original_max_size * factor)
        if new_size != self._max_size:
            self._max_size = new_size
            self._evict()
            return True
        return False

    def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool:
        """Returns True if the entity may have been updated since stream_pos
        """
        assert type(stream_pos) in integer_types

        if stream_pos < self._earliest_known_stream_pos:
            self.metrics.inc_misses()
            return True

        latest_entity_change_pos = self._entity_to_key.get(entity, None)
        if latest_entity_change_pos is None:
            self.metrics.inc_hits()
            return False

        if stream_pos < latest_entity_change_pos:
            self.metrics.inc_misses()
            return True

        self.metrics.inc_hits()
        return False

    def get_entities_changed(
            self, entities: Collection[EntityType],
            stream_pos: int) -> Union[Set[EntityType], FrozenSet[EntityType]]:
        """
        Returns subset of entities that have had new things since the given
        position.  Entities unknown to the cache will be returned.  If the
        position is too old it will just return the given list.
        """
        changed_entities = self.get_all_entities_changed(stream_pos)
        if changed_entities is not None:
            # We now do an intersection, trying to do so in the most efficient
            # way possible (some of these sets are *large*). First check if the
            # given iterable is already a set that we can reuse; otherwise we
            # create a set of the *smallest* of the two iterables and call
            # `intersection(..)` on it (this can be twice as fast as the reverse).
            if isinstance(entities, (set, frozenset)):
                result = entities.intersection(changed_entities)
            elif len(changed_entities) < len(entities):
                result = set(changed_entities).intersection(entities)
            else:
                result = set(entities).intersection(changed_entities)
            self.metrics.inc_hits()
        else:
            result = set(entities)
            self.metrics.inc_misses()

        return result

    def has_any_entity_changed(self, stream_pos: int) -> bool:
        """Returns if any entity has changed
        """
        assert type(stream_pos) is int

        if not self._cache:
            # If the cache is empty, nothing can have changed.
            return False

        if stream_pos >= self._earliest_known_stream_pos:
            self.metrics.inc_hits()
            return self._cache.bisect_right(stream_pos) < len(self._cache)
        else:
            self.metrics.inc_misses()
            return True

    def get_all_entities_changed(
            self, stream_pos: int) -> Optional[List[EntityType]]:
        """Returns all entities that have had new things since the given
        position. If the position is too old it will return None.

        Returns the entities in the order that they were changed.
        """
        assert type(stream_pos) is int

        if stream_pos < self._earliest_known_stream_pos:
            return None

        changed_entities = []  # type: List[EntityType]

        for k in self._cache.islice(
                start=self._cache.bisect_right(stream_pos)):
            changed_entities.extend(self._cache[k])
        return changed_entities

    def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
        """Informs the cache that the entity has been changed at the given
        position.
        """
        assert type(stream_pos) is int

        if stream_pos <= self._earliest_known_stream_pos:
            return

        old_pos = self._entity_to_key.get(entity, None)
        if old_pos is not None:
            if old_pos >= stream_pos:
                # nothing to do
                return
            e = self._cache[old_pos]
            e.remove(entity)
            if not e:
                # cache at this point is now empty
                del self._cache[old_pos]

        e1 = self._cache.get(stream_pos)
        if e1 is None:
            e1 = self._cache[stream_pos] = set()
        e1.add(entity)
        self._entity_to_key[entity] = stream_pos
        self._evict()

    def _evict(self):
        while len(self._cache) > self._max_size:
            k, r = self._cache.popitem(0)
            self._earliest_known_stream_pos = max(
                k, self._earliest_known_stream_pos)
            for entity in r:
                self._entity_to_key.pop(entity, None)

    def get_max_pos_of_last_change(self, entity: EntityType) -> int:
        """Returns an upper bound of the stream id of the last change to an
        entity.
        """
        return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
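
The intersection logic in get_entities_changed can be restated standalone; this sketch (hypothetical helper name) mirrors the branching: reuse an existing set, otherwise build a set from the smaller side before intersecting.

from typing import AbstractSet, Iterable

def intersect_changed(entities: Iterable[str],
                      changed: AbstractSet[str]) -> AbstractSet[str]:
    if isinstance(entities, (set, frozenset)):
        return entities.intersection(changed)
    entities = list(entities)
    if len(changed) < len(entities):
        # build the set from the smaller side; intersection() accepts any iterable
        return set(changed).intersection(entities)
    return set(entities).intersection(changed)

assert intersect_changed(["@a:hs", "@b:hs"], {"@b:hs", "@c:hs"}) == {"@b:hs"}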
示例#31
0
class FederationRemoteSendQueue(object):
    """A drop in replacement for TransactionQueue"""
    def __init__(self, hs):
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        self.presence_map = {
        }  # Pending presence map user_id -> UserPresenceState
        self.presence_changed = SortedDict()  # Stream position -> user_id

        self.keyed_edu = {}  # (destination, key) -> EDU
        self.keyed_edu_changed = SortedDict(
        )  # stream position -> (destination, key)

        self.edus = SortedDict()  # stream position -> Edu

        self.device_messages = SortedDict()  # stream position -> destination

        self.pos = 1
        self.pos_time = SortedDict()

        # EVERYTHING IS SAD. In particular, python only makes new scopes when
        # we make a new function, so we need to make a new function so the inner
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name, queue):
            LaterGauge(
                "synapse_federation_send_queue_%s_size" % (name, ), "",
                [], lambda: len(queue))

        for queue_name in [
                "presence_map",
                "presence_changed",
                "keyed_edu",
                "keyed_edu_changed",
                "edus",
                "device_messages",
                "pos_time",
        ]:
            register(queue_name, getattr(self, queue_name))

        self.clock.looping_call(self._clear_queue, 30 * 1000)

    def _next_pos(self):
        pos = self.pos
        self.pos += 1
        self.pos_time[self.clock.time_msec()] = pos
        return pos

    def _clear_queue(self):
        """Clear the queues for anything older than N minutes"""

        FIVE_MINUTES_AGO = 5 * 60 * 1000
        now = self.clock.time_msec()

        keys = self.pos_time.keys()
        time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
        if not keys[:time]:
            return

        position_to_delete = max(keys[:time])
        for key in keys[:time]:
            del self.pos_time[key]

        self._clear_queue_before_pos(position_to_delete)

    def _clear_queue_before_pos(self, position_to_delete):
        """Clear all the queues from before a given position"""
        with Measure(self.clock, "send_queue._clear"):
            # Delete things out of presence maps
            keys = self.presence_changed.keys()
            i = self.presence_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_changed[key]

            user_ids = set(user_id
                           for uids in itervalues(self.presence_changed)
                           for user_id in uids)

            to_del = [
                user_id for user_id in self.presence_map
                if user_id not in user_ids
            ]
            for user_id in to_del:
                del self.presence_map[user_id]

            # Delete things out of keyed edus
            keys = self.keyed_edu_changed.keys()
            i = self.keyed_edu_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.keyed_edu_changed[key]

            live_keys = set()
            for edu_key in self.keyed_edu_changed.values():
                live_keys.add(edu_key)

            to_del = [
                edu_key for edu_key in self.keyed_edu
                if edu_key not in live_keys
            ]
            for edu_key in to_del:
                del self.keyed_edu[edu_key]

            # Delete things out of edu map
            keys = self.edus.keys()
            i = self.edus.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.edus[key]

            # Delete things out of device map
            keys = self.device_messages.keys()
            i = self.device_messages.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.device_messages[key]

    def notify_new_events(self, current_id):
        """As per TransactionQueue"""
        # We don't need to replicate this as it gets sent down a different
        # stream.
        pass

    def send_edu(self, destination, edu_type, content, key=None):
        """As per TransactionQueue"""
        pos = self._next_pos()

        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        if key:
            assert isinstance(key, tuple)
            self.keyed_edu[(destination, key)] = edu
            self.keyed_edu_changed[pos] = (destination, key)
        else:
            self.edus[pos] = edu

        self.notifier.on_new_replication_data()

    def send_presence(self, states):
        """As per TransactionQueue

        Args:
            states (list(UserPresenceState))
        """
        pos = self._next_pos()

        # We only want to send presence for our own users, so let's always
        # filter here just in case.
        local_states = list(
            filter(lambda s: self.is_mine_id(s.user_id), states))

        self.presence_map.update(
            {state.user_id: state
             for state in local_states})
        self.presence_changed[pos] = [state.user_id for state in local_states]

        self.notifier.on_new_replication_data()

    def send_device_messages(self, destination):
        """As per TransactionQueue"""
        pos = self._next_pos()
        self.device_messages[pos] = destination
        self.notifier.on_new_replication_data()

    def get_current_token(self):
        return self.pos - 1

    def federation_ack(self, token):
        self._clear_queue_before_pos(token)

    def get_replication_rows(self,
                             from_token,
                             to_token,
                             limit,
                             federation_ack=None):
        """Get rows to be sent over federation between the two tokens

        Args:
            from_token (int)
            to_token (int)
            limit (int)
            federation_ack (int): Optional. The position the worker has
                explicitly acknowledged it has handled. Allows us to drop
                data from before that point
        """
        # TODO: Handle limit.

        # To handle restarts where we wrap around
        if from_token > self.pos:
            from_token = -1

        # list of tuple(int, BaseFederationRow), where the first is the position
        # of the federation stream.
        rows = []

        # There should be only one reader, so let's delete everything it has
        # acknowledged it has seen.
        if federation_ack:
            self._clear_queue_before_pos(federation_ack)

        # Fetch changed presence
        i = self.presence_changed.bisect_right(from_token)
        j = self.presence_changed.bisect_right(to_token) + 1
        dest_user_ids = [
            (pos, user_id)
            for pos, user_id_list in self.presence_changed.items()[i:j]
            for user_id in user_id_list
        ]

        for (key, user_id) in dest_user_ids:
            rows.append((key, PresenceRow(state=self.presence_map[user_id], )))

        # Fetch changed keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)
        j = self.keyed_edu_changed.bisect_right(to_token) + 1
        # We purposefully clobber based on the key here, python dict comprehensions
        # always use the last value, so this will correctly point to the last
        # stream position.
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in iteritems(keyed_edus):
            rows.append((pos,
                         KeyedEduRow(
                             key=edu_key,
                             edu=self.keyed_edu[(destination, edu_key)],
                         )))

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)
        j = self.edus.bisect_right(to_token) + 1
        edus = self.edus.items()[i:j]

        for (pos, edu) in edus:
            rows.append((pos, EduRow(edu)))

        # Fetch changed device messages
        i = self.device_messages.bisect_right(from_token)
        j = self.device_messages.bisect_right(to_token) + 1
        device_messages = {v: k for k, v in self.device_messages.items()[i:j]}

        for (destination, pos) in iteritems(device_messages):
            rows.append((pos, DeviceRow(destination=destination, )))

        # Sort rows based on pos
        rows.sort()

        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
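
Each of the four fetches in get_replication_rows uses the same token window over a SortedDict; in isolation, with made-up payloads and keeping the original + 1 on the upper bound:

from sortedcontainers import SortedDict

edus = SortedDict({1: "edu-1", 4: "edu-4", 9: "edu-9"})  # stream pos -> payload

from_token, to_token = 1, 9
i = edus.bisect_right(from_token)    # strictly after from_token
j = edus.bisect_right(to_token) + 1  # through to_token (+ 1 as in the code above)
assert list(edus.items()[i:j]) == [(4, "edu-4"), (9, "edu-9")]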
示例#32
0
def test6():
    """
    Ordered map: SortedDict
    Docs: http://www.grantjenks.com/docs/sortedcontainers/sorteddict.html
    """
    from sortedcontainers import SortedDict
    sd = SortedDict()
    # insert / update elements
    sd["wxx"] = 21
    sd["hh"] = 18
    sd["other"] = 20
    print(sd)  # SortedDict({'hh': 18, 'other': 20, 'wxx': 21})
    print(sd["wxx"])  # indexing a missing key raises KeyError
    print(sd.get("c"))  # get() returns None for a missing key     None
    # convert a SortedDict to a plain dict
    print(dict(sd))  # {'hh': 18, 'other': 20, 'wxx': 21}
    # peek at the first and the last element
    print(sd.peekitem(0))  # tuple, the first element    ('hh', 18)
    print(sd.peekitem())  # tuple, the last element    ('wxx', 21)
    # iteration
    for k, v in sd.items():
        print(k, ':', v, sep="", end=", ")  # sep suppresses the separator between fields
    print()
    for k in sd:  # iterate over keys, equivalent to: for k in sd.keys():
        print(str(k) + ":" + str(sd[k]), end=", ")
    print()
    for v in sd.values():  # iterate over values
        print(v, end=", ")
    print()
    # fetch a key from the map
    print(sd.peekitem()[0])
    # fetch a value from the map
    print(sd.peekitem()[1])
    # test whether a key is present
    print("wxx" in sd)  # True
    # bisect_left() / bisect_right()
    sd["a"] = 1
    sd["c1"] = 2
    sd["c2"] = 4
    print(
        sd
    )  # SortedDict({'a': 1, 'c1': 2, 'c2': 4, 'hh': 18, 'other': 20, 'wxx': 21})
    print(sd.bisect_left("c1"))  # index of the first key >= "c1"    1
    print(sd.bisect_right("c1"))  # index of the first key > "c1"    2
    # clear the map
    sd.clear()
    print(len(sd))  # 0
    print(len(sd) == 0)  # True
    """
    Unordered map: dict
    """
    print("---------------------------------------")
    d = {"c1": 2, "c2": 4, "hh": 18, "wxx": 21, 13: 14, 1: 0}
    print(d["wxx"])  # 21
    print(d[13])  # 14
    d[13] += 1
    print(d[13])  # 15
    d["future"] = "wonderful"  # add a key-value pair
    del d[1]  # delete the entry with key 1 from d
    print("wxx" in d)  # True if the key "wxx" is in d, otherwise False
    print(d.keys())  # all keys of d    dict_keys(['c1', 'c2', 'hh', 'wxx', 13])
    print(d.values())  # all values of d    dict_values([2, 4, 18, 21, 14])
    print(d.items(
    ))  # dict_items([('c1', 2), ('c2', 4), ('hh', 18), ('wxx', 21), (13, 14)])
    for k, v in d.items():  # iterate over (k, v) pairs
        print(k, ':', v)
    for k in d:  # iterate over keys, equivalent to: for k in d.keys():
        print(str(k) + ":" + str(d[k]), end=", ")
    print()
    for v in d.values():  # iterate over values
        print(v, end=", ")
    print()
    # dict functions and methods
    print("---------------------------------------")
    d = {"中国": "北京", "美国": "华盛顿", "法国": "巴黎"}
    print(len(d))  # number of entries in d    3
    print(d.get("中国", "不存在"))  # value for the key if present, else the default    北京
    print(d.get("中", "不存在"))  # 不存在
    print(d.get("中"))  # None
    d["美国"] = "Washington"  # update the value for an existing key
    print(d.pop("美国"))  # value for the key if present, removing the entry as well
    print(d.popitem())  # remove and return an arbitrary (key, value) tuple
    d.clear()  # delete all key-value pairs
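
Beyond bisect_left/bisect_right, a SortedDict can iterate a key window directly; a small addition to the tour above, using irange (by key) and islice (by index):

from sortedcontainers import SortedDict

sd = SortedDict({"a": 1, "c1": 2, "c2": 4, "hh": 18, "other": 20, "wxx": 21})
# irange yields keys in [minimum, maximum] without materializing a slice
assert list(sd.irange("c1", "hh")) == ["c1", "c2", "hh"]
# islice works on positions, pairing naturally with bisect_*
assert list(sd.islice(sd.bisect_left("c1"), sd.bisect_right("c2"))) == ["c1", "c2"]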
示例#33
0
class FederationRemoteSendQueue(object):
    """A drop in replacement for FederationSender"""

    def __init__(self, hs):
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
        self.presence_changed = SortedDict()  # Stream position -> list[user_id]

        # Stores the destinations we need to explicitly send presence to about a
        # given user.
        # Stream position -> (user_id, destinations)
        self.presence_destinations = SortedDict()

        self.keyed_edu = {}  # (destination, key) -> EDU
        self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)

        self.edus = SortedDict()  # stream position -> Edu

        self.device_messages = SortedDict()  # stream position -> destination

        self.pos = 1
        self.pos_time = SortedDict()

        # EVERYTHING IS SAD. In particular, python only makes new scopes when
        # we make a new function, so we need to make a new function so the inner
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name, queue):
            LaterGauge("synapse_federation_send_queue_%s_size" % (name,),
                       "", [], lambda: len(queue))

        for queue_name in [
            "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
            "edus", "device_messages", "pos_time", "presence_destinations",
        ]:
            register(queue_name, getattr(self, queue_name))

        self.clock.looping_call(self._clear_queue, 30 * 1000)

    def _next_pos(self):
        pos = self.pos
        self.pos += 1
        self.pos_time[self.clock.time_msec()] = pos
        return pos

    def _clear_queue(self):
        """Clear the queues for anything older than N minutes"""

        FIVE_MINUTES_AGO = 5 * 60 * 1000
        now = self.clock.time_msec()

        keys = self.pos_time.keys()
        time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
        if not keys[:time]:
            return

        position_to_delete = max(keys[:time])
        for key in keys[:time]:
            del self.pos_time[key]

        self._clear_queue_before_pos(position_to_delete)

    def _clear_queue_before_pos(self, position_to_delete):
        """Clear all the queues from before a given position"""
        with Measure(self.clock, "send_queue._clear"):
            # Delete things out of presence maps
            keys = self.presence_changed.keys()
            i = self.presence_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_changed[key]

            user_ids = set(
                user_id
                for uids in self.presence_changed.values()
                for user_id in uids
            )

            keys = self.presence_destinations.keys()
            i = self.presence_destinations.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_destinations[key]

            user_ids.update(
                user_id for user_id, _ in self.presence_destinations.values()
            )

            to_del = [
                user_id for user_id in self.presence_map if user_id not in user_ids
            ]
            for user_id in to_del:
                del self.presence_map[user_id]

            # Delete things out of keyed edus
            keys = self.keyed_edu_changed.keys()
            i = self.keyed_edu_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.keyed_edu_changed[key]

            live_keys = set()
            for edu_key in self.keyed_edu_changed.values():
                live_keys.add(edu_key)

            to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
            for edu_key in to_del:
                del self.keyed_edu[edu_key]

            # Delete things out of edu map
            keys = self.edus.keys()
            i = self.edus.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.edus[key]

            # Delete things out of device map
            keys = self.device_messages.keys()
            i = self.device_messages.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.device_messages[key]

    def notify_new_events(self, current_id):
        """As per FederationSender"""
        # We don't need to replicate this as it gets sent down a different
        # stream.
        pass

    def build_and_send_edu(self, destination, edu_type, content, key=None):
        """As per FederationSender"""
        if destination == self.server_name:
            logger.info("Not sending EDU to ourselves")
            return

        pos = self._next_pos()

        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        if key:
            assert isinstance(key, tuple)
            self.keyed_edu[(destination, key)] = edu
            self.keyed_edu_changed[pos] = (destination, key)
        else:
            self.edus[pos] = edu

        self.notifier.on_new_replication_data()

    def send_read_receipt(self, receipt):
        """As per FederationSender

        Args:
            receipt (synapse.types.ReadReceipt):
        """
        # nothing to do here: the replication listener will handle it.
        pass

    def send_presence(self, states):
        """As per FederationSender

        Args:
            states (list(UserPresenceState))
        """
        pos = self._next_pos()

        # We only want to send presence for our own users, so let's always
        # filter here just in case.
        local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))

        self.presence_map.update({state.user_id: state for state in local_states})
        self.presence_changed[pos] = [state.user_id for state in local_states]

        self.notifier.on_new_replication_data()

    def send_presence_to_destinations(self, states, destinations):
        """As per FederationSender

        Args:
            states (list[UserPresenceState])
            destinations (list[str])
        """
        self.presence_map.update({state.user_id: state for state in states})
        for state in states:
            pos = self._next_pos()
            self.presence_destinations[pos] = (state.user_id, destinations)

        self.notifier.on_new_replication_data()

    def send_device_messages(self, destination):
        """As per FederationSender"""
        pos = self._next_pos()
        self.device_messages[pos] = destination
        self.notifier.on_new_replication_data()

    def get_current_token(self):
        return self.pos - 1

    def federation_ack(self, token):
        self._clear_queue_before_pos(token)

    def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
        """Get rows to be sent over federation between the two tokens

        Args:
            from_token (int)
            to_token (int)
            limit (int)
            federation_ack (int): Optional. The position the worker has
                explicitly acknowledged it has handled. Allows us to drop
                data from before that point
        """
        # TODO: Handle limit.

        # To handle restarts where we wrap around
        if from_token > self.pos:
            from_token = -1

        # list of tuple(int, BaseFederationRow), where the first is the position
        # of the federation stream.
        rows = []

        # There should be only one reader, so let's delete everything it has
        # acknowledged it has seen.
        if federation_ack:
            self._clear_queue_before_pos(federation_ack)

        # Fetch changed presence
        i = self.presence_changed.bisect_right(from_token)
        j = self.presence_changed.bisect_right(to_token) + 1
        dest_user_ids = [
            (pos, user_id)
            for pos, user_id_list in self.presence_changed.items()[i:j]
            for user_id in user_id_list
        ]

        for (key, user_id) in dest_user_ids:
            rows.append((key, PresenceRow(
                state=self.presence_map[user_id],
            )))

        # Fetch presence to send to destinations
        i = self.presence_destinations.bisect_right(from_token)
        j = self.presence_destinations.bisect_right(to_token) + 1

        for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
            rows.append((pos, PresenceDestinationsRow(
                state=self.presence_map[user_id],
                destinations=list(dests),
            )))

        # Fetch changed keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)
        j = self.keyed_edu_changed.bisect_right(to_token) + 1
        # We purposefully clobber based on the key here, python dict comprehensions
        # always use the last value, so this will correctly point to the last
        # stream position.
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in iteritems(keyed_edus):
            rows.append((pos, KeyedEduRow(
                key=edu_key,
                edu=self.keyed_edu[(destination, edu_key)],
            )))

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)
        j = self.edus.bisect_right(to_token) + 1
        edus = self.edus.items()[i:j]

        for (pos, edu) in edus:
            rows.append((pos, EduRow(edu)))

        # Fetch changed device messages
        i = self.device_messages.bisect_right(from_token)
        j = self.device_messages.bisect_right(to_token) + 1
        device_messages = {v: k for k, v in self.device_messages.items()[i:j]}

        for (destination, pos) in iteritems(device_messages):
            rows.append((pos, DeviceRow(
                destination=destination,
            )))

        # Sort rows based on pos
        rows.sort()

        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
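
The "clobber" comprehension used for keyed EDUs and device messages deduplicates by value while keeping the highest stream position, because a SortedDict iterates in key order; the trick standalone, with made-up EDU keys:

from sortedcontainers import SortedDict

keyed_edu_changed = SortedDict({
    1: ("hs1", "m.typing"),   # superseded by the entry at position 3
    3: ("hs1", "m.typing"),
    5: ("hs2", "m.receipt"),
})
# later items overwrite earlier ones, so each key maps to its last position
keyed_edus = {v: k for k, v in keyed_edu_changed.items()}
assert keyed_edus == {("hs1", "m.typing"): 3, ("hs2", "m.receipt"): 5}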
示例#34
0
class StreamChangeCache(object):
    """Keeps track of the stream positions of the latest change in a set of entities.

    Typically the entity will be a room or user id.

    Given a list of entities and a stream position, it will give a subset of
    entities that may have changed since that position. If position key is too
    old then the cache will simply return all given entities.
    """
    def __init__(self,
                 name,
                 current_stream_pos,
                 max_size=10000,
                 prefilled_cache=None):
        self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
        self._entity_to_key = {}
        self._cache = SortedDict()
        self._earliest_known_stream_pos = current_stream_pos
        self.name = name
        self.metrics = caches.register_cache("cache", self.name, self._cache)

        if prefilled_cache:
            for entity, stream_pos in prefilled_cache.items():
                self.entity_has_changed(entity, stream_pos)

    def has_entity_changed(self, entity, stream_pos):
        """Returns True if the entity may have been updated since stream_pos
        """
        assert type(stream_pos) is int or type(stream_pos) is long

        if stream_pos < self._earliest_known_stream_pos:
            self.metrics.inc_misses()
            return True

        latest_entity_change_pos = self._entity_to_key.get(entity, None)
        if latest_entity_change_pos is None:
            self.metrics.inc_hits()
            return False

        if stream_pos < latest_entity_change_pos:
            self.metrics.inc_misses()
            return True

        self.metrics.inc_hits()
        return False

    def get_entities_changed(self, entities, stream_pos):
        """
        Returns subset of entities that have had new things since the given
        position.  Entities unknown to the cache will be returned.  If the
        position is too old it will just return the given list.
        """
        assert type(stream_pos) is int

        if stream_pos >= self._earliest_known_stream_pos:
            not_known_entities = set(entities) - set(self._entity_to_key)

            result = ({
                self._cache[k]
                for k in self._cache.islice(
                    start=self._cache.bisect_right(stream_pos))
            }.intersection(entities).union(not_known_entities))

            self.metrics.inc_hits()
        else:
            result = set(entities)
            self.metrics.inc_misses()

        return result

    def has_any_entity_changed(self, stream_pos):
        """Returns if any entity has changed
        """
        assert type(stream_pos) is int

        if not self._cache:
            # If we have no cache, nothing can have changed.
            return False

        if stream_pos >= self._earliest_known_stream_pos:
            self.metrics.inc_hits()
            return self._cache.bisect_right(stream_pos) < len(self._cache)
        else:
            self.metrics.inc_misses()
            return True

    def get_all_entities_changed(self, stream_pos):
        """Returns all entites that have had new things since the given
        position. If the position is too old it will return None.
        """
        assert type(stream_pos) is int

        if stream_pos >= self._earliest_known_stream_pos:
            return [
                self._cache[k] for k in self._cache.islice(
                    start=self._cache.bisect_right(stream_pos))
            ]
        else:
            return None

    def entity_has_changed(self, entity, stream_pos):
        """Informs the cache that the entity has been changed at the given
        position.
        """
        assert type(stream_pos) is int

        if stream_pos > self._earliest_known_stream_pos:
            old_pos = self._entity_to_key.get(entity, None)
            if old_pos is not None:
                stream_pos = max(stream_pos, old_pos)
                self._cache.pop(old_pos, None)
            self._cache[stream_pos] = entity
            self._entity_to_key[entity] = stream_pos

            while len(self._cache) > self._max_size:
                k, r = self._cache.popitem(0)
                self._earliest_known_stream_pos = max(
                    k,
                    self._earliest_known_stream_pos,
                )
                self._entity_to_key.pop(r, None)

    def get_max_pos_of_last_change(self, entity):
        """Returns an upper bound of the stream id of the last change to an
        entity.
        """
        return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
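
The eviction loop in entity_has_changed pops the oldest stream positions and advances the "too old to answer" watermark; the pattern in isolation (sample data, not Synapse code):

from sortedcontainers import SortedDict

cache = SortedDict({1: "@a:hs", 2: "@b:hs", 3: "@c:hs"})  # stream pos -> entity
earliest_known_stream_pos = 0
max_size = 2

while len(cache) > max_size:
    k, r = cache.popitem(0)  # popitem(0) removes the smallest (oldest) key
    earliest_known_stream_pos = max(k, earliest_known_stream_pos)

assert list(cache) == [2, 3] and earliest_known_stream_pos == 1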