Example #1
0
class SortedSetKey:
    """A sorted set whose ordering keys are stored externally in a dict.

    Each value's sort rank lives in ``self.dict``; ``self.sorted_set``
    orders values by looking the rank up through :meth:`get_key`.
    """

    def __init__(self):
        """Create an empty rank mapping and a set ordered by it."""
        self.dict = {}
        self.sorted_set = SortedSet(key=self.get_key)

    def __getitem__(self, item):
        """Index or slice the underlying sorted set."""
        return self.sorted_set[item]

    def __len__(self):
        """Number of stored values."""
        return len(self.sorted_set)

    def __str__(self):
        """Delegate the string form to the sorted set."""
        return str(self.sorted_set)

    def get_key(self, value):
        """Sort key for *value*: its externally stored rank."""
        return self.dict[value]

    def get_reversed_list(self, index, count):
        """Return up to *count* values walking backwards, starting at the
        *index*-th element from the end."""
        start = -1 - index
        return self[start:start - count:-1]

    def values(self):
        """Yield the values in sorted order."""
        yield from self.sorted_set

    def clear(self):
        """Drop every value and its rank."""
        self.sorted_set.clear()
        self.dict.clear()

    def destroy(self):
        """Release the sorted set (the rank dict is intentionally kept)."""
        self.sorted_set = None

    def index(self, value):
        """Position of *value* in sorted order."""
        return self.sorted_set.index(value)

    def pop(self, index=-1):
        """Remove and return the value at *index* (last by default)."""
        return self.sorted_set.pop(index)

    def add(self, value, rank):
        """Insert *value* with *rank*, re-positioning it if already present."""
        if value in self.sorted_set:
            # Must remove before the rank changes, or the set's internal
            # ordering would no longer match the stored key.
            self.sorted_set.remove(value)
        self.dict[value] = rank
        self.sorted_set.add(value)

    def remove(self, value):
        """Remove *value* and forget its rank; raises if absent."""
        self.sorted_set.remove(value)
        del self.dict[value]

    def update(self, value_list, rank_list):
        """Bulk-assign ranks to *value_list* and (re-)insert all of them."""
        # Remove first so re-ranked values are re-sorted on insertion.
        self.sorted_set.difference_update(value_list)
        for position, value in enumerate(value_list):
            self.dict[value] = rank_list[position]
        self.sorted_set.update(value_list)
def test_init():
    """Exercise SortedSet construction, _reset, bulk init and clear."""
    # A freshly constructed empty set must satisfy its invariants.
    empty = SortedSet()
    empty._check()

    # _reset should install the requested load factor on the backing list.
    resized = SortedSet()
    resized._reset(10000)
    assert resized._list._load == 10000
    resized._check()

    # Bulk construction from a range keeps the values in ascending order.
    filled = SortedSet(range(10000))
    assert all(a == b for a, b in zip(filled, range(10000)))

    # clear() must leave the set completely empty and consistent.
    filled.clear()
    assert len(filled) == 0
    assert list(iter(filled)) == []
    filled._check()
Example #3
0
def test_init():
    """Verify SortedSet initialisation, load reset and clear semantics."""
    sst = SortedSet()
    sst._check()  # empty set passes the internal invariant check

    sst = SortedSet()
    sst._reset(10000)
    # The backing SortedList must have picked up the new load factor.
    assert sst._list._load == 10000
    sst._check()

    sst = SortedSet(range(10000))
    expected = 0
    for value in sst:
        assert value == expected
        expected += 1

    sst.clear()
    assert not len(sst)
    assert list(sst) == []
    sst._check()
class PlatformBatchLightSystem:
    """Batch light system for platforms.

    Collects "dirty" lights (lights whose state changed), groups runs of
    hardware-sequential lights into batches and pushes each batch to the
    platform through ``update_callback``.  Lights that are mid-fade are
    put on ``dirty_schedule`` so they are refreshed when the fade is due.
    """

    __slots__ = [
        "dirty_lights", "dirty_schedule", "clock", "is_sequential_function",
        "update_task", "update_callback", "sort_function", "update_hz",
        "max_batch_size"
    ]

    # pylint: disable-msg=too-many-arguments
    def __init__(self, clock, sort_function, is_sequential_function,
                 update_callback, update_hz, max_batch_size):
        """Initialise light system.

        Args:
            clock: Object exposing ``get_time()`` and an asyncio ``loop``.
            sort_function: Key function defining a total order over lights.
            is_sequential_function: Predicate ``(prev, light) -> bool`` that
                is true when the two lights are adjacent and may share a batch.
            update_callback: Coroutine called with a list of
                ``(light, brightness, fade_ms)`` tuples.
            update_hz: Desired update rate.  NOTE(review): stored but the
                send loop below sleeps a fixed 1 ms instead — confirm intent.
            max_batch_size: Maximum number of tuples per callback call.
        """
        self.dirty_lights = SortedSet(
            key=sort_function)  # type: Set[PlatformBatchLight]
        # (due_time, light) entries; the key folds the light's sort key into
        # the due time so ordering is deterministic for equal times.
        self.dirty_schedule = SortedList(
            key=lambda x: x[0] + sort_function(x[1]))
        self.is_sequential_function = is_sequential_function
        self.sort_function = sort_function
        self.update_task = None
        self.clock = clock
        self.update_callback = update_callback
        self.update_hz = update_hz
        self.max_batch_size = max_batch_size

    def start(self):
        """Start light system."""
        self.update_task = self.clock.loop.create_task(self._send_updates())
        self.update_task.add_done_callback(self._done)

    def stop(self):
        """Stop light system."""
        if self.update_task:
            self.update_task.cancel()
            self.update_task = None

    @staticmethod
    def _done(future):
        # Consume the task result so unexpected exceptions surface;
        # cancellation caused by stop() is expected and swallowed.
        try:
            future.result()
        except asyncio.CancelledError:
            pass

    async def _send_updates(self):
        """Main loop: promote due fades, batch dirty lights, send batches."""
        while True:
            # Promote lights whose scheduled fade refresh is now due.
            while self.dirty_schedule and self.dirty_schedule[0][
                    0] <= self.clock.get_time():
                self.dirty_lights.add(self.dirty_schedule[0][1])
                del self.dirty_schedule[0]

            # Walk dirty lights in sort order, grouping runs of
            # hardware-sequential lights into one batch each.
            sequential_lights = []
            for light in list(self.dirty_lights):
                if not sequential_lights:
                    # first light
                    sequential_lights = [light]
                elif self.is_sequential_function(sequential_lights[-1], light):
                    # lights are sequential
                    sequential_lights.append(light)
                else:
                    # sequence ended
                    await self._send_update_batch(sequential_lights)
                    # this light is a new sequence
                    sequential_lights = [light]

            if sequential_lights:
                await self._send_update_batch(sequential_lights)

            self.dirty_lights.clear()

            # NOTE(review): the ``loop=`` argument to asyncio.sleep was
            # removed in Python 3.10; this code targets an older asyncio.
            await asyncio.sleep(.001, loop=self.clock.loop)

    async def _send_update_batch(self, sequential_lights):
        """Send one run of sequential lights, splitting the run whenever the
        fade time differs from the batch's common fade or the batch is full."""
        sequential_brightness_list = []
        common_fade_ms = None
        current_time = self.clock.get_time()
        for light in sequential_lights:
            brightness, fade_ms, done = light.get_fade_and_brightness(
                current_time)
            if not done:
                # Fade still running: revisit this light when it completes.
                self.dirty_schedule.add(
                    (current_time + (fade_ms / 1000), light))
            if common_fade_ms is None:
                common_fade_ms = fade_ms

            if common_fade_ms == fade_ms and len(
                    sequential_brightness_list) < self.max_batch_size:
                sequential_brightness_list.append(
                    (light, brightness, common_fade_ms))
            else:
                await self.update_callback(sequential_brightness_list)
                # start new list
                current_time = self.clock.get_time()
                common_fade_ms = fade_ms
                sequential_brightness_list = [(light, brightness,
                                               common_fade_ms)]

        if sequential_brightness_list:
            await self.update_callback(sequential_brightness_list)

    def mark_dirty(self, light: "PlatformBatchLight"):
        """Mark as dirty."""
        self.dirty_lights.add(light)
        # Rebuild the schedule without stale entries for this light.
        self.dirty_schedule = SortedList(
            [x for x in self.dirty_schedule if x[1] != light],
            key=lambda x: x[0] + self.sort_function(x[1]))
def test_clear():
    """Clearing a populated SortedSet leaves it empty and consistent."""
    populated = SortedSet(range(100), load=7)
    populated.clear()
    populated._check()
    assert not len(populated)
Example #6
0
class Chunk(object):
    """
    Represents a chunk of code providing some useful functionality in the system.

    A chunk belongs to a feature, depends on other chunks (inside or outside
    its feature) and may contain bugs.  The probabilities steering its
    evolution are delegated to the owning feature's software system.
    """

    def __init__(self, logical_name, feature, local_content=None):
        self.logical_name = logical_name

        self.feature = feature

        self.local_content = local_content

        # Sorted containers keep iteration order deterministic across runs.
        self.dependencies = SortedSet(key=lambda d: d.fully_qualified_name)
        self.bugs = SortedSet(key=lambda b: b.logical_name)

        # Monotonic counter used to name newly inserted bugs.
        self.bug_count = 0

    def __eq__(self, other):
        """Chunks are equal when their content, bug names and dependency
        names all match.

        The *_logical_names properties return ``map`` iterators; they are
        materialised into lists here because comparing two ``map`` objects
        directly falls back to identity on Python 3 and is never equal.
        """
        if self.local_content != other.local_content:
            return False
        if list(self.bugs_logical_names) != list(other.bugs_logical_names):
            return False
        if list(self.dependency_logical_names) != list(other.dependency_logical_names):
            return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def probability_gain_feature_dependency(self):
        return self.feature.software_system.probability_gain_feature_dependency

    @property
    def probability_lose_feature_dependency(self):
        return self.feature.software_system.probability_lose_feature_dependency

    @property
    def probability_gain_system_dependency(self):
        return self.feature.software_system.probability_gain_system_dependency

    @property
    def probability_lose_system_dependency(self):
        return self.feature.software_system.probability_lose_system_dependency

    @property
    def probability_new_bug(self):
        return self.feature.software_system.probability_new_bug

    @property
    def probability_debug_known(self):
        return self.feature.software_system.probability_debug_known

    @property
    def probability_debug_unknown(self):
        return self.feature.software_system.probability_debug_unknown

    @property
    def dependency_logical_names(self):
        """Iterator over the logical names of this chunk's dependencies."""
        return map(lambda d: d.logical_name, self.dependencies)

    @property
    def bugs_logical_names(self):
        """Iterator over the logical names of this chunk's bugs."""
        return map(lambda b: b.logical_name, self.bugs)

    @property
    def bugs_in_dependencies(self):
        """Union of all bugs found in this chunk's direct dependencies."""
        chunk_bug_set = frozenset(map(lambda chunk: frozenset(chunk.bugs), self.dependencies))
        return reduce(lambda bugs_a, bugs_b: bugs_a.union(bugs_b), chunk_bug_set, set())

    @property
    def tests(self):
        """Tests of the owning feature that cover this chunk."""
        return filter(lambda t: self in t.chunks, self.feature.tests)

    def modify(self, random):
        """Randomly gain dependencies, rewrite content and insert bugs."""
        feature_chunks = self.feature.chunks - {self}
        system_chunks = set(self.feature.software_system.chunks.difference(self.feature.chunks))
        self._add_dependencies(random, system_chunks, self.probability_gain_system_dependency)
        self._add_dependencies(random, feature_chunks, self.probability_gain_feature_dependency)

        self.local_content = random.create_local_content()

        self._insert_bugs(random)

    def merge(self, source_chunk, random):
        """Adopt *source_chunk*'s dependencies (resolved into this system),
        then modify this chunk."""
        for dependency in source_chunk.dependencies:
            working_copy_dependency = self.feature.software_system.get_chunk(dependency.fully_qualified_name)
            self.dependencies.add(working_copy_dependency)

        self.modify(random)

    def overwrite_with(self, source_chunk):
        """Replace content, bugs and dependencies with *source_chunk*'s."""
        self.local_content = source_chunk.local_content

        self.bugs.clear()
        for old_bug in source_chunk.bugs:
            new_bug = self.get_bug(old_bug.logical_name)
            if new_bug is None:
                self.add_bug(old_bug.logical_name)

        self.dependencies.clear()
        for dependency in source_chunk.dependencies:
            new_dependency = self.feature.software_system.get_chunk(dependency.fully_qualified_name)
            self.dependencies.add(new_dependency)

    def _add_dependencies(self, random, candidate_chunks, threshold):
        # Iterate candidates in a deterministic (sorted) order so results
        # are reproducible for a fixed random source.
        for candidate in SortedSet(candidate_chunks, key=lambda c: c.logical_name):
            if random.dependency_should_be_added(threshold):
                self.add_dependency(candidate)

    def add_dependency(self, candidate):
        self.dependencies.add(candidate)

    def _insert_bugs(self, random):
        while random.a_bug_should_be_inserted(self):
            self.add_bug(self.bug_count)
            self.bug_count += 1

    def add_bug(self, logical_name):
        self.bugs.add(Bug(logical_name, self))

    def get_bug(self, logical_name):
        """Return the bug with *logical_name*, or None when absent.

        Uses ``next`` over a generator instead of the original
        ``len()``/indexing on a ``filter`` object, which raises TypeError
        on Python 3 (and compared with ``is 0``).
        """
        return next(
            (bug for bug in self.bugs if bug.logical_name == logical_name),
            None)

    def refactor(self, random):
        """Randomly drop some dependencies."""
        to_remove = set()
        for dependency in self.dependencies:
            if random.dependency_should_be_removed(self, dependency):
                to_remove.add(dependency)

        self.dependencies.difference_update(to_remove)

    def debug(self, random, bug=None):
        """Attempt to remove *bug* (known) or a random bug (unknown)."""
        if len(self.bugs) == 0:
            return False

        if bug is None or bug not in self.bugs:
            if random.unknown_bug_should_be_removed(self):
                bug = random.choose_bug(self)
                self.bugs.remove(bug)
        elif random.known_bug_should_be_removed(self):
            self.bugs.remove(bug)

    def operate(self, random):
        """Let every reachable bug (own and in dependencies) manifest."""
        for bug in self.bugs_in_dependencies.union(self.bugs):
            bug.manifest(random)

    def __str__(self):
        def string_repr_set(iterable):
            return ",".join(map(lambda e: repr(e), iterable))

        feature_dependencies = string_repr_set(filter(lambda c: c.feature == self.feature, self.dependencies))
        system_dependencies = string_repr_set(filter(lambda c: c.feature != self.feature, self.dependencies))

        bugs = ", ".join(map(lambda bug: str(bug), self.bugs))

        return "c_%s:[%s]:[%s]->(in[%s],ex[%s])" % \
               (str(self.logical_name), self.local_content, bugs, feature_dependencies, system_dependencies)

    @property
    def fully_qualified_name(self):
        return "%s.%s" % (str(self.feature.logical_name), str(self.logical_name))

    def __repr__(self):
        return "c%s" % str(self.fully_qualified_name)
Example #7
0
def test_clear():
    """_reset followed by clear leaves the set empty and valid."""
    subject = SortedSet(range(100))
    subject._reset(7)
    subject.clear()
    subject._check()
    assert len(subject) == 0
Example #8
0
def test5():
    """
    Ordered set: SortedSet
    Docs: http://www.grantjenks.com/docs/sortedcontainers/sortedset.html
    """
    from sortedcontainers import SortedSet
    # Create a SortedSet
    ss = SortedSet([3, 1, 2, 5, 4])
    print(ss)  # SortedSet([1, 2, 3, 4, 5])
    from operator import neg
    ss1 = SortedSet([3, 1, 2, 5, 4], neg)
    print(ss1)  # SortedSet([5, 4, 3, 2, 1], key=<built-in function neg>)
    # Convert a SortedSet to list/tuple/set
    print(list(ss))  # SortedSet to list    [1, 2, 3, 4, 5]
    print(tuple(ss))  # SortedSet to tuple    (1, 2, 3, 4, 5)
    print(set(ss))  # SortedSet to set    {1, 2, 3, 4, 5}
    # Insert and delete elements
    ss.discard(-1)  # discard() does not raise for a missing element
    ss.remove(1)  # remove() raises KeyError for a missing element
    ss.discard(3)  # SortedSet([2, 4, 5])
    ss.add(-10)  # SortedSet([-10, 2, 4, 5])
    # First and last element
    print(ss[0])  # -10
    print(ss[-1])  # 5
    # Iterate over the set
    for e in ss:
        print(e, end=", ")  # -10, 2, 4, 5,
    print()
    # Membership test
    print(2 in ss)  # True
    # bisect_left() / bisect_right()
    print(ss.bisect_left(4))  # index of the smallest element >= 4    2
    print(ss.bisect_right(4))  # index of the smallest element > 4    3
    # Empty the set
    ss.clear()
    print(len(ss))  # 0
    print(len(ss) == 0)  # True
    # (No-op string literal below serves as a section header:
    #  unordered built-in set basics.)
    """
    无序的集合: set
    """
    # Set elements must be hashable (immutable), so a list cannot be one
    A = {"hi", 2, ("we", 24)}
    B = set()  # empty set; B = {} would create a dict, not a set
    # Set operations; each operator below also has an op= in-place form
    print("---------------------------------------")
    S = {1, 2, 3}
    T = {3, 4, 5}
    print(S & T)  # intersection: new set of elements in both S and T
    print(S | T)  # union: new set with all elements of S and T
    print(S - T)  # difference: new set of elements in S but not in T
    print(S ^ T)  # symmetric difference: elements in exactly one of S, T
    # Subset / superset relations
    print("---------------------------------------")
    C = {1, 2}
    D = {1, 2}
    print(C <= D)  # is C a subset of D?  True
    print(C < D)  # is C a proper subset of D?  False
    print(C >= D)  # is D a subset of C?  True
    print(C > D)  # is D a proper subset of C?  False
    # Set methods
    print("---------------------------------------")
    S = {1, 2, 3, 5, 6}
    S.add(4)  # add x to S if not already present
    S.discard(1)  # remove x from S; no error if x is absent
    S.remove(2)  # remove x from S; KeyError if x is absent
    for e in S:  # iterate
        print(e, end=",")
    print()
    print(S.pop())  # pop an arbitrary element; KeyError if S is empty
    print(S.copy())  # shallow copy; changes to it do not affect S
    print(len(S))  # number of elements in S
    print(5 in S)  # True if 5 is in S
    print(5 not in S)  # True if 5 is NOT in S
    S.clear()  # remove all elements from S
class PlatformBatchLightSystem:

    """Batch light system for platforms.

    Event-driven variant: a scheduler task wakes when fades complete and
    marks lights dirty; a sender task groups hardware-sequential dirty
    lights into batches and pushes them through ``update_callback``.
    ``last_state`` remembers the last realised brightness per light so
    redundant updates can be skipped.
    """

    __slots__ = ["dirty_lights", "dirty_schedule", "clock", "update_task", "update_callback",
                 "update_hz", "max_batch_size", "scheduler_task", "schedule_changed", "dirty_lights_changed",
                 "last_state"]

    # pylint: disable-msg=too-many-arguments
    def __init__(self, clock, update_callback, update_hz, max_batch_size):
        """Initialise light system.

        Args:
            clock: Object exposing ``get_time()`` and an asyncio ``loop``.
            update_callback: Coroutine receiving a list of
                ``(light, brightness, fade_ms)`` tuples.
            update_hz: Poll rate of the sender loop.
            max_batch_size: Maximum tuples per callback invocation.
        """
        self.dirty_lights = SortedSet()    # type: Set[PlatformBatchLight]
        self.dirty_lights_changed = asyncio.Event()
        # (due_time, light) pairs ordered by due time.
        self.dirty_schedule = SortedList()
        self.schedule_changed = asyncio.Event()
        self.update_task = None
        self.scheduler_task = None
        self.clock = clock
        self.update_callback = update_callback
        self.update_hz = update_hz
        self.max_batch_size = max_batch_size
        # light -> (brightness, time realised); lets completed, unchanged
        # lights be skipped in _send_update_batch.
        self.last_state = {}

    def start(self):
        """Start light system."""
        self.update_task = self.clock.loop.create_task(self._send_updates())
        self.update_task.add_done_callback(Util.raise_exceptions)
        self.scheduler_task = self.clock.loop.create_task(self._schedule_updates())
        self.scheduler_task.add_done_callback(Util.raise_exceptions)

    def stop(self):
        """Stop light system."""
        if self.scheduler_task:
            self.scheduler_task.cancel()
            self.scheduler_task = None
        if self.update_task:
            self.update_task.cancel()
            self.update_task = None

    async def _schedule_updates(self):
        """Promote lights whose fades are due onto the dirty set, then sleep
        until the next deadline or until the schedule changes."""
        while True:
            run_time = self.clock.get_time()
            self.schedule_changed.clear()
            while self.dirty_schedule and self.dirty_schedule[0][0] <= run_time:
                self.dirty_lights.add(self.dirty_schedule[0][1])
                del self.dirty_schedule[0]
            self.dirty_lights_changed.set()

            if self.dirty_schedule:
                try:
                    # Wake early if the schedule changes before the deadline.
                    await asyncio.wait_for(self.schedule_changed.wait(), self.dirty_schedule[0][0] - run_time)
                except asyncio.TimeoutError:
                    pass
            else:
                await self.schedule_changed.wait()

    async def _send_updates(self):
        """Batch dirty lights into sequential runs and send them out."""
        poll_sleep_time = 1 / self.update_hz
        # Fades that differ by less than one poll interval are batched as equal.
        max_fade_tolerance = int(poll_sleep_time * 1000)
        while True:
            await self.dirty_lights_changed.wait()
            self.dirty_lights_changed.clear()
            sequential_lights = []
            for light in list(self.dirty_lights):
                if not sequential_lights:
                    # first light
                    sequential_lights = [light]
                elif light.is_successor_of(sequential_lights[-1]):
                    # lights are sequential
                    sequential_lights.append(light)
                else:
                    # sequence ended
                    await self._send_update_batch(sequential_lights, max_fade_tolerance)
                    # this light is a new sequence
                    sequential_lights = [light]

            if sequential_lights:
                await self._send_update_batch(sequential_lights, max_fade_tolerance)

            self.dirty_lights.clear()

            await asyncio.sleep(poll_sleep_time)

    async def _send_update_batch(self, sequential_lights: List[PlatformBatchLight], max_fade_tolerance):
        """Send one run of sequential lights, splitting on fade mismatch
        (beyond tolerance) or when the batch size limit is reached."""
        sequential_brightness_list = []     # type: List[Tuple[LightPlatformInterface, float, int]]
        common_fade_ms = None
        current_time = self.clock.get_time()
        for light in sequential_lights:
            brightness, fade_ms, done = light.get_fade_and_brightness(current_time)
            schedule_time = current_time + (fade_ms / 1000)
            if not done:
                # Fade still running: wake the scheduler if this entry
                # becomes the new earliest deadline.
                if not self.dirty_schedule or self.dirty_schedule[0][0] > schedule_time:
                    self.schedule_changed.set()
                self.dirty_schedule.add((schedule_time, light))
            else:
                # check if we realized this brightness earlier
                last_state = self.last_state.get(light, None)
                if last_state and last_state[0] == brightness and last_state[1] < schedule_time and \
                        not sequential_brightness_list:
                    # we already set the light to that color earlier. skip it
                    # we only skip this light if we are in the beginning of the list for now
                    # the reason for that is that we do not want to break fade chains when one color channel
                    # of an RGB light did not change
                    # this could become an option in the future
                    continue

            self.last_state[light] = (brightness, schedule_time)

            if common_fade_ms is None:
                common_fade_ms = fade_ms

            if -max_fade_tolerance < common_fade_ms - fade_ms < max_fade_tolerance and \
                    len(sequential_brightness_list) < self.max_batch_size:
                sequential_brightness_list.append((light, brightness, common_fade_ms))
            else:
                await self.update_callback(sequential_brightness_list)
                # start new list
                current_time = self.clock.get_time()
                common_fade_ms = fade_ms
                sequential_brightness_list = [(light, brightness, common_fade_ms)]

        if sequential_brightness_list:
            await self.update_callback(sequential_brightness_list)

    def mark_dirty(self, light: "PlatformBatchLight"):
        """Mark as dirty."""
        self.dirty_lights.add(light)
        self.dirty_lights_changed.set()
        # Rebuild the schedule without stale entries for this light.
        self.dirty_schedule = SortedList([x for x in self.dirty_schedule if x[1] != light])
Example #10
0
class PlatformBatchLightSystem:
    """Batch light system for platforms.

    Scheduler/sender task pair: the scheduler promotes lights whose fades
    are due back onto the dirty set; the sender groups hardware-sequential
    dirty lights into batches for ``update_callback``.
    """

    __slots__ = [
        "dirty_lights", "dirty_schedule", "clock", "update_task",
        "update_callback", "update_hz", "max_batch_size", "scheduler_task",
        "schedule_changed", "dirty_lights_changed"
    ]

    # pylint: disable-msg=too-many-arguments
    def __init__(self, clock, update_callback, update_hz, max_batch_size):
        """Initialise light system.

        NOTE(review): the ``loop=`` argument to asyncio.Event (and to
        asyncio.sleep/asyncio.wait below) was removed in Python 3.10;
        this code targets an older asyncio API — confirm target version.
        """
        self.dirty_lights = SortedSet()  # type: Set[PlatformBatchLight]
        self.dirty_lights_changed = asyncio.Event(loop=clock.loop)
        # (due_time, light) pairs ordered by due time.
        self.dirty_schedule = SortedList()
        self.schedule_changed = asyncio.Event(loop=clock.loop)
        self.update_task = None
        self.scheduler_task = None
        self.clock = clock
        self.update_callback = update_callback
        self.update_hz = update_hz
        self.max_batch_size = max_batch_size

    def start(self):
        """Start light system."""
        self.update_task = self.clock.loop.create_task(self._send_updates())
        self.update_task.add_done_callback(Util.raise_exceptions)
        self.scheduler_task = self.clock.loop.create_task(
            self._schedule_updates())
        self.scheduler_task.add_done_callback(Util.raise_exceptions)

    def stop(self):
        """Stop light system."""
        if self.scheduler_task:
            self.scheduler_task.cancel()
            self.scheduler_task = None
        if self.update_task:
            self.update_task.cancel()
            self.update_task = None

    async def _schedule_updates(self):
        """Promote due lights to the dirty set, then sleep until the next
        deadline or a schedule change, whichever comes first."""
        while True:
            run_time = self.clock.get_time()
            self.schedule_changed.clear()
            while self.dirty_schedule and self.dirty_schedule[0][0] <= run_time:
                self.dirty_lights.add(self.dirty_schedule[0][1])
                del self.dirty_schedule[0]
            self.dirty_lights_changed.set()

            if self.dirty_schedule:
                # Wake at the next deadline or when the schedule changes.
                await asyncio.wait([self.schedule_changed.wait()],
                                   loop=self.clock.loop,
                                   timeout=self.dirty_schedule[0][0] -
                                   run_time,
                                   return_when=asyncio.FIRST_COMPLETED)
            else:
                await self.schedule_changed.wait()

    async def _send_updates(self):
        """Batch dirty lights into sequential runs and send them out."""
        poll_sleep_time = 1 / self.update_hz
        # Fades within one poll interval are batched as equal.
        max_fade_tolerance = int(poll_sleep_time * 1000)
        while True:
            await self.dirty_lights_changed.wait()
            self.dirty_lights_changed.clear()
            sequential_lights = []
            for light in list(self.dirty_lights):
                if not sequential_lights:
                    # first light
                    sequential_lights = [light]
                elif light.is_successor_of(sequential_lights[-1]):
                    # lights are sequential
                    sequential_lights.append(light)
                else:
                    # sequence ended
                    await self._send_update_batch(sequential_lights,
                                                  max_fade_tolerance)
                    # this light is a new sequence
                    sequential_lights = [light]

            if sequential_lights:
                await self._send_update_batch(sequential_lights,
                                              max_fade_tolerance)

            self.dirty_lights.clear()

            await asyncio.sleep(poll_sleep_time, loop=self.clock.loop)

    async def _send_update_batch(self,
                                 sequential_lights: List[PlatformBatchLight],
                                 max_fade_tolerance):
        """Send one run of sequential lights, splitting on fade mismatch
        (beyond tolerance) or batch-size overflow."""
        sequential_brightness_list = [
        ]  # type: List[Tuple[LightPlatformInterface, float, int]]
        common_fade_ms = None
        current_time = self.clock.get_time()
        for light in sequential_lights:
            brightness, fade_ms, done = light.get_fade_and_brightness(
                current_time)
            if not done:
                # Fade still running: wake the scheduler if this entry
                # becomes the new earliest deadline.
                schedule_time = current_time + (fade_ms / 1000)
                if not self.dirty_schedule or self.dirty_schedule[0][
                        0] > schedule_time:
                    self.schedule_changed.set()
                self.dirty_schedule.add((schedule_time, light))
            if common_fade_ms is None:
                common_fade_ms = fade_ms

            if -max_fade_tolerance < common_fade_ms - fade_ms < max_fade_tolerance and \
                    len(sequential_brightness_list) < self.max_batch_size:
                sequential_brightness_list.append(
                    (light, brightness, common_fade_ms))
            else:
                await self.update_callback(sequential_brightness_list)
                # start new list
                current_time = self.clock.get_time()
                common_fade_ms = fade_ms
                sequential_brightness_list = [(light, brightness,
                                               common_fade_ms)]

        if sequential_brightness_list:
            await self.update_callback(sequential_brightness_list)

    def mark_dirty(self, light: "PlatformBatchLight"):
        """Mark as dirty."""
        self.dirty_lights.add(light)
        self.dirty_lights_changed.set()
        # Rebuild the schedule without stale entries for this light.
        self.dirty_schedule = SortedList(
            [x for x in self.dirty_schedule if x[1] != light])
Example #11
0
def main():
    """Build a bipartite client/server matching over random latent points
    and improve it epoch by epoch with batched updates.

    Heavy lifting (graph construction, ES-graph maintenance, augmenting
    path search) is delegated to project helpers (createGraph,
    initializeESGraph, addBatch, ES.*); this driver wires them together
    and reports timings and matching weights.
    """
    n = 50000
    d = 10
    #max_level = floor(sqrt(n) * sqrt(log(n))) #=735 for n=50000
    max_level = 6
    n_nbrs = 11
    n_rndms = 0
    source_node = -1
    n_epochs = 1
    batch_size = 200

    #placeholder for input
    latentPoints = normalize(np.random.normal(0, 1, (n, d)))
    #latentPoints = normalize(np.load('latent_points_10.npy')[:n])

    # time.clock() was removed in Python 3.8; perf_counter() is its
    # documented replacement for interval timing.
    start1 = time.perf_counter()
    #---before training---
    (G, client_nodes, server_nodes, annoy_index,
     targetPoints) = createGraph(n=n,
                                 d=d,  # was hard-coded 10; use the local d
                                 latentPoints=latentPoints,
                                 n_trees=60,
                                 n_nbrs=n_nbrs,
                                 n_rndms=n_rndms)
    end1 = time.perf_counter()
    print('Created G bipartate graph. Elapsed time: ', end1 - start1)

    #initialize H before the clients are being added iteratively
    H = nx.DiGraph()
    #initialize M digraph of the matching directed from C to S
    M = nx.DiGraph()
    M.add_nodes_from(client_nodes)
    M.add_nodes_from(server_nodes)
    #initialize F bipartate graph of forces
    #F = nx.DiGraph()
    #initialize levels
    levels = {}
    best_gains = {}
    # Candidate improvements ordered by their gain (second tuple element).
    free_gains = SortedSet(key=lambda x: x[1])
    #initialize parents, (node, level): SortedSet(parents of the node on the given level)
    parents_by_level = {}
    #for l in range(max_level + 2):
    #    parents_by_level[(source_node, l)] = SortedSet(key=lambda x: x[1])
    for c in client_nodes:
        for l in range(max_level + 2):
            parents_by_level[(c, l)] = SortedSet(key=lambda x: x[1])
    for s in server_nodes:
        for l in range(max_level + 2):
            parents_by_level[(s, l)] = SortedSet(key=lambda x: x[1])

    for e in range(n_epochs):  #simulate epochs
        start2 = time.perf_counter()
        initializeESGraph(H,
                          parents_by_level,
                          levels,
                          best_gains,
                          client_nodes,
                          server_nodes,
                          source_node=source_node)
        for i in range(n //
                       batch_size):  #simulate training by batches on an epoch
            batch_indices = SortedSet(
                range(i * batch_size, (i + 1) * batch_size))  #placeholder
            #latentBatch = normalize(np.random.normal(0, 1, size=(batch_size, 10))) #placeholder
            latentBatch = latentPoints[i * batch_size:(i + 1) *
                                       batch_size]  # placeholder
            addBatch(G,
                     n,
                     annoy_index,
                     batch_indices,
                     latentBatch,
                     targetPoints,
                     H,
                     parents_by_level,
                     levels,
                     best_gains,
                     M,
                     max_level,
                     n_nbrs=n_nbrs,
                     n_rndms=n_rndms)
        end2 = time.perf_counter()

        print('number of matching edges / n : ', len(M.edges()) / n)

        #deepcopy M to F
        #F.add_nodes_from(M.nodes)
        #F.add_edges_from(M.edges)
        print('Created the initial ES graph. Elapsed time: ', end2 - start2)

        weight_of_matching = 0
        for (u, v, wt) in M.edges.data('weight'):
            weight_of_matching = weight_of_matching + wt
        print('Weight of the found matching: ', weight_of_matching)
        print('Average weight of an edge in the matching: ',
              weight_of_matching / len(M.edges()))

        # Keep applying augmenting paths until no negative-gain candidate
        # remains.
        improvable = True
        while improvable:
            for s in server_nodes:
                if levels[s] > 1 and levels[s] <= max_level:
                    #print('s: ', s)
                    #print(levels[s])
                    #print((s, parents_by_level[(s, levels[s] - 1)][0][1]))
                    free_gains.add(
                        (s, parents_by_level[(s, levels[s] - 1)][0][1]))
            #print(len(free_gains))
            print(free_gains[0])
            if free_gains[0][1] >= 0:
                improvable = False
            else:
                (path, found_short_path) = ES.findSAP(H, parents_by_level,
                                                      levels, best_gains,
                                                      free_gains[0][0])
                ES.applyServerPath(H, parents_by_level, levels, best_gains, M,
                                   path, max_level, found_short_path,
                                   source_node)
            free_gains.clear()

        weight_of_matching = 0
        for (u, v, wt) in M.edges.data('weight'):
            weight_of_matching = weight_of_matching + wt
        print('Weight of the improved matching: ', weight_of_matching)
        print('Average weight of an edge in the matching: ',
              weight_of_matching / len(M.edges()))

        #reset for the next epoch
        G.clear()
        H.clear()
        M.clear()
        M.add_nodes_from(client_nodes)
        M.add_nodes_from(server_nodes)

        #for l in range(max_level + 2):
        #    parents_by_level[(source_node, l)].clear()
        for c in client_nodes:
            for l in range(max_level + 2):
                parents_by_level[(c, l)].clear()
        for s in server_nodes:
            for l in range(max_level + 2):
                parents_by_level[(s, l)].clear()

        #placeholder for latent points
        latentPoints = normalize(np.random.normal(0, 1, (n, d)))
Example #12
0
class Display:
    def __init__(self, interface, dimensions):
        """Set up the display buffer, dirty tracking and cursor state."""
        self.logger = logging.getLogger(__name__)

        self.interface = interface
        self.dimensions = dimensions

        (num_rows, num_columns) = self.dimensions

        # One byte per visible cell; indices changed locally but not yet
        # written to the device are tracked in the dirty set.
        self.buffer = bytearray(num_rows * num_columns)
        self.dirty = SortedSet()

        # Unknown until the first load or explicit read from the interface.
        self.address_counter = None

        self.status_line = StatusLine(self)

        self.cursor_reverse = False
        self.cursor_blink = False

    def move_cursor(self, index=None, row=None, column=None, force_load=False):
        """Load the address counter."""
        address = self._calculate_address(index=index, row=row, column=column)

        # TODO: Verify that the address is within range - exclude status line.

        return self._load_address_counter(address, force_load)

    def buffered_write(self, byte, index=None, row=None, column=None):
        if index is None:
            if row is None or column is None:
                raise ValueError('Either index or row and column is required')

            index = self._get_index(row, column)

        # TODO: Verify that index is within range.

        if self.buffer[index] == byte:
            return False

        self.buffer[index] = byte

        self.dirty.add(index)

        return True

    def flush(self):
        for (start_index, end_index) in self._get_dirty_ranges():
            self._flush_range(start_index, end_index)

    def clear(self, clear_status_line=False):
        """Clear the screen.

        When clear_status_line is True the first device row (before the
        visible screen, per _calculate_address's column offset) is wiped too.
        """
        (rows, columns) = self.dimensions

        if clear_status_line:
            start_address = 0
            cell_count = (rows + 1) * columns
        else:
            start_address = columns
            cell_count = rows * columns

        self._write((b'\x00', cell_count), address=start_address)

        # Update the buffer and dirty indicators to reflect the cleared screen.
        for position in range(rows * columns):
            self.buffer[position] = 0x00

        self.dirty.clear()

        self.move_cursor(row=0, column=0, force_load=True)

    def toggle_cursor_blink(self):
        self.cursor_blink = not self.cursor_blink

    def toggle_cursor_reverse(self):
        self.cursor_reverse = not self.cursor_reverse

    def _get_index(self, row, column):
        return (row * self.dimensions.columns) + column

    def _calculate_address(self, index=None, row=None, column=None):
        if index is not None:
            return self.dimensions.columns + index

        if row is not None and column is not None:
            return self.dimensions.columns + self._get_index(row, column)

        raise ValueError('Either index or row and column is required')

    def _calculate_address_after_write(self, address, count):
        """Return the expected address counter after writing count bytes
        starting at address, or None if it cannot be determined."""
        if address is None:
            return None

        address += count

        (rows, columns) = self.dimensions

        # TODO: Determine the correct behavior here...
        # NOTE(review): this compares the stale self.address_counter (the
        # pre-write counter) against the last screen address - it looks like
        # it should test the advanced local address instead; confirm intent.
        # NOTE(review): self.address_counter can be None here (set in
        # __init__ and by _write), which would make >= raise TypeError.
        if self.address_counter >= self._calculate_address((rows * columns) -
                                                           1):
            return None

        return address

    def _read_address_counter(self):
        """Read the 16-bit address counter back from the interface."""
        high_byte = read_address_counter_hi(self.interface)
        low_byte = read_address_counter_lo(self.interface)

        return (high_byte << 8) | low_byte

    def _load_address_counter(self, address, force_load):
        if address == self.address_counter and not force_load:
            return False

        (hi, lo) = _split_address(address)
        (current_hi, current_lo) = _split_address(self.address_counter)

        if hi != current_hi or force_load:
            load_address_counter_hi(self.interface, hi)

        if lo != current_lo or force_load:
            load_address_counter_lo(self.interface, lo)

        self.address_counter = address

        return True

    def _get_dirty_ranges(self):
        if not self.dirty:
            return []

        # TODO: Implement multiple ranges with optimization.
        return [(self.dirty[0], self.dirty[-1])]

    def _flush_range(self, start_index, end_index):
        """Write the buffered bytes in [start_index, end_index] to the device.

        Returns the resulting address counter value.
        """
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(
                f'Flushing changes for range {start_index}-{end_index}')

        data = self.buffer[start_index:end_index + 1]
        address = self._calculate_address(start_index)

        try:
            self._write(data, address=address)
        except Exception as error:
            # TODO: This could leave the address_counter incorrect.
            self.logger.error(f'Write error: {error}', exc_info=error)

        # The range is marked clean even when the write failed.
        for position in range(start_index, end_index + 1):
            self.dirty.discard(position)

        return self.address_counter

    def _write(self, data, address=None, restore_original_address=False):
        """Write data to the terminal and track the resulting address counter.

        data is either a bytes-like object or a (pattern, count) tuple
        meaning "repeat pattern count times" (see clear()). If address is
        given the counter is loaded first; if restore_original_address is
        True the counter is restored after the write.
        """
        if restore_original_address:
            original_address = self.address_counter

            if original_address is None:
                original_address = self._read_address_counter()

        if address is not None:
            self._load_address_counter(address, force_load=False)

        write_data(self.interface, data)

        # BUGFIX: the repeat-tuple form is carried by data, not address
        # (address is an int) - the original tested isinstance(address,
        # tuple), so a repeated write computed length as len(data) == 2
        # instead of pattern-length * count, corrupting the tracked counter.
        if isinstance(data, tuple):
            length = len(data[0]) * data[1]
        else:
            length = len(data)

        self.address_counter = self._calculate_address_after_write(
            address, length)

        if restore_original_address:
            self._load_address_counter(original_address, force_load=True)