コード例 #1
0
def test_copy():
    """Adding to a copy of a SortedSet must leave the original's length unchanged."""
    original = SortedSet(range(100))
    # Private sortedcontainers hook; presumably lowers the load factor so the items span
    # several internal sublists — confirm against the installed sortedcontainers version.
    original._reset(7)
    clone = original.copy()
    clone.add(1000)
    assert len(original) == 100
    assert len(clone) == 101
コード例 #2
0
def test_copy():
    """A SortedSet copy is independent of the set it was copied from."""
    source = SortedSet(range(100))
    source._reset(7)  # private sortedcontainers API (load factor) — TODO confirm semantics
    duplicate = source.copy()
    duplicate.add(1000)
    assert (len(source), len(duplicate)) == (100, 101)
コード例 #3
0
class CardSet(_CardSetImpl):
    """A collection of cards kept sorted by ``(cost, name)``."""

    def __init__(self, *, elms=None):
        # Cards are ordered by their ``cost`` attribute first, then by ``name``.
        self._data = SortedSet(elms, key=operator.attrgetter('cost', 'name'))

    def contains(self, card_name: 'CardName'):
        """
        Return True if any card satisfies the ``has_attr(name=card_name)`` predicate.

        Iterates the backing set directly: going through ``data`` would build (and re-sort)
        a full copy of the set on every call.
        """
        matches_name = has_attr(name=card_name)
        return any(matches_name(card) for card in self._data)

    @property
    def data(self) -> SortedSet:
        """A copy of the underlying sorted set; modifying it does not affect this CardSet."""
        return self._data.copy()
コード例 #4
0
def test_copy():
    """
    A SortedSet copy is independent: adding to the copy must not change the original.

    The original asserted ``len(temp) == 101``, which contradicts copy semantics (and the
    sibling tests). It also passed ``load=7`` to the constructor, which SortedSet does not
    accept; the load factor is set through ``_reset`` instead, as the other examples do.
    """
    temp = SortedSet(range(100))
    temp._reset(7)  # private sortedcontainers API (load factor) — TODO confirm semantics
    that = temp.copy()
    that.add(1000)
    assert len(that) == 101
    assert len(temp) == 100
コード例 #5
0
ファイル: fog.py プロジェクト: marcgarreau/py-trie
class HexaryTrieFog:
    """
    Keeps track of which parts of a trie have been verified to exist.

    Named after "fog of war" popular in video games like... Red Alert? IDK, I'm old.

    Object is immutable. Any changes, like marking a key prefix as complete, will
    return a new HexaryTrieFog object.

    The state is a sorted set of the key prefixes that are still unexplored. A fog with no
    unexplored prefixes left is complete. Two fogs are equal when their sets of unexplored
    prefixes are equal.
    """
    # Key prefixes (as Nibbles) that have not been explored yet. They are kept sorted so that
    #   nearest_unknown() and nearest_right() can bisect them.
    _unexplored_prefixes: GenericSortedSet[Nibbles]

    # INVARIANT: No unexplored prefix may start with another unexplored prefix
    #   For example, _unexplored_prefixes may not be {(1, 2), (1, 2, 3)}.

    def __init__(self) -> None:
        # Always start without knowing anything about a trie. The only unexplored
        #   prefix is the root prefix: (), which means the whole trie is unexplored.
        self._unexplored_prefixes = SortedSet({()})

    def __repr__(self) -> str:
        return f"HexaryTrieFog<{self._unexplored_prefixes!r}>"

    @property
    def is_complete(self) -> bool:
        """True when no unexplored prefixes remain, i.e. the whole trie has been explored."""
        return len(self._unexplored_prefixes) == 0

    def explore(
            self,
            old_prefix_input: NibblesInput,
            foggy_sub_segments: Sequence[NibblesInput]) -> 'HexaryTrieFog':
        """
        The fog lifts from the old prefix. This call returns a new HexaryTrieFog that narrows
        down the unexplored key prefixes, from the old prefix to the indicated children.

        For example, if only the key prefix 0x12 is unexplored, then calling
        explore((1, 2), ((3,), (0xe, 0xf))) would mark large swaths of 0x12 explored, leaving only
        two prefixes as unknown: 0x123 and 0x12ef. To continue exploring those prefixes, navigate
        to them using traverse() or traverse_from().

        foggy_sub_segments may be empty, which means the old prefix has been fully explored.

        :param old_prefix_input: a prefix that is currently unexplored in this fog
        :param foggy_sub_segments: suffixes, relative to the old prefix, that are still unexplored;
            each one is appended to the old prefix to form a new unexplored prefix
        :return: a new HexaryTrieFog; this object is left unchanged
        :raises ValidationError: if the old prefix is not currently unexplored, if a sub-segment
            is given twice, or if one sub-segment is a prefix of another
        """
        old_prefix = Nibbles(old_prefix_input)
        sub_segments = [Nibbles(segment) for segment in foggy_sub_segments]
        # Work on a copy so that this fog object stays unchanged
        new_fog_prefixes = self._unexplored_prefixes.copy()

        try:
            new_fog_prefixes.remove(old_prefix)
        except KeyError:
            raise ValidationError(f"Old parent {old_prefix} not found in {new_fog_prefixes!r}")

        if len(set(sub_segments)) != len(sub_segments):
            raise ValidationError(
                f"Got duplicate sub_segments in {sub_segments} to HexaryTrieFog.explore()"
            )

        # Further validation that no segment is a prefix of another
        all_lengths = set(len(segment) for segment in sub_segments)
        if len(all_lengths) > 1:
            # The known use case of exploring nodes one at a time will never arrive in this
            #   validation check which might be slow. Leaf nodes have no sub segments,
            #   extension nodes have exactly one, and branch nodes have all sub_segments
            #   of length 1. If a new use case hits this verification, and speed becomes an issue,
            #   see https://github.com/ethereum/py-trie/issues/107
            for segment in sub_segments:
                shorter_lengths = [length for length in all_lengths if length < len(segment)]
                for check_length in shorter_lengths:
                    trimmed_segment = segment[:check_length]
                    if trimmed_segment in sub_segments:
                        raise ValidationError(
                            f"Cannot add {segment} which is a child of segment {trimmed_segment}"
                        )

        # The children of the old prefix become the new unexplored prefixes
        new_fog_prefixes.update([old_prefix + segment for segment in sub_segments])
        return self._new_trie_fog(new_fog_prefixes)

    def mark_all_complete(self, prefix_inputs: Sequence[NibblesInput]) -> 'HexaryTrieFog':
        """
        Mark every given prefix as fully explored, and return the resulting new fog.

        These might be leaves, or prefixes with 0 unknown keys within the range.

        This is equivalent to the following, but with better performance:

            result_fog = old_fog
            for complete_prefix in prefix_inputs:
                result_fog = result_fog.explore(complete_prefix, ())

        :raises ValidationError: if any prefix is not currently unexplored (this includes a
            prefix that appears twice in prefix_inputs)
        """
        new_unexplored_prefixes = self._unexplored_prefixes.copy()
        for prefix in map(Nibbles, prefix_inputs):
            if prefix not in new_unexplored_prefixes:
                raise ValidationError(
                    f"When marking {prefix} complete, could not find in {new_unexplored_prefixes!r}"
                )

            new_unexplored_prefixes.remove(prefix)
        return self._new_trie_fog(new_unexplored_prefixes)

    def nearest_unknown(self, key_input: NibblesInput = ()) -> Nibbles:
        """
        Find the foggy prefix that is nearest to the supplied key.

        If prefixes are exactly the same distance to the left and right,
        then return the prefix on the right.

        :param key_input: the key to search around; defaults to the empty (root) key
        :return: the unexplored prefix nearest to the key
        :raises PerfectVisibility: if there are no foggy prefixes remaining
        """
        key = Nibbles(key_input)

        # SortedSet.bisect is bisect_right: every prefix before `index` sorts <= key and
        #   every prefix from `index` onwards sorts > key.
        index = self._unexplored_prefixes.bisect(key)

        if index == 0:
            # If sorted set is empty, bisect will return 0
            # But it might also return 0 if the search value is lower than the lowest existing
            try:
                return self._unexplored_prefixes[0]
            except IndexError as exc:
                raise PerfectVisibility("There are no more unexplored prefixes") from exc
        elif index == len(self._unexplored_prefixes):
            # The key sorts after every unexplored prefix, so the last one is nearest
            return self._unexplored_prefixes[-1]
        else:
            nearest_left = self._unexplored_prefixes[index - 1]
            nearest_right = self._unexplored_prefixes[index]

            # is the left or right unknown prefix closer?
            # (a tie falls through to the right-hand prefix, as documented above)
            left_distance = self._prefix_distance(nearest_left, key)
            right_distance = self._prefix_distance(key, nearest_right)
            if left_distance < right_distance:
                return nearest_left
            else:
                return nearest_right

    def nearest_right(self, key_input: NibblesInput) -> Nibbles:
        """
        Find the foggy prefix that is nearest on the right to the supplied key.

        If the nearest unexplored prefix on the left is a prefix of the key, that prefix
        is returned instead, because the key itself is still inside the fog.

        :raises PerfectVisibility: if there are no foggy prefixes at all
        :raises FullDirectionalVisibility: if there are no foggy prefixes to the right
        """
        key = Nibbles(key_input)

        # bisect_right: prefixes before `index` sort <= key, prefixes from `index` on sort > key
        index = self._unexplored_prefixes.bisect(key)

        if index == 0:
            # If sorted set is empty, bisect will return 0
            # But it might also return 0 if the search value is lower than the lowest existing
            try:
                return self._unexplored_prefixes[0]
            except IndexError as exc:
                raise PerfectVisibility("There are no more unexplored prefixes") from exc
        else:
            nearest_left = self._unexplored_prefixes[index - 1]

            # always return nearest right, unless prefix of key is unexplored
            if key_starts_with(key, nearest_left):
                return nearest_left
            else:
                try:
                    # This can raise a IndexError if index == len(unexplored prefixes)
                    return self._unexplored_prefixes[index]
                except IndexError as exc:
                    raise FullDirectionalVisibility(
                        f"There are no unexplored prefixes to the right of {key}"
                    ) from exc

    @staticmethod
    @to_tuple
    def _prefix_distance(low_key: Nibbles, high_key: Nibbles) -> Iterable[int]:
        """
        How far are the two keys from each other, as a sequence of differences.
        The first non-zero distance must be positive, but the remaining distances may
        be negative. Distances are designed to be simply compared, like distance1 < distance2.

        The high_key must be higher than the low key, or the output distances are not
        guaranteed to be accurate.

        When the keys differ in length, a missing low nibble counts as 15 (the highest nibble)
        and a missing high nibble counts as 0 (the lowest nibble).
        """
        for low_nibble, high_nibble in zip_longest(low_key, high_key, fillvalue=None):
            if low_nibble is None:
                final_low_nibble = 15
            else:
                final_low_nibble = low_nibble

            if high_nibble is None:
                final_high_nibble = 0
            else:
                final_high_nibble = high_nibble

            # Note: this might return a negative value. It's fine, because only the
            #   relative distance matters. For example (1, 2) and (2, 1) produce a
            #   distance of (1, -1). If the other reference point is (3, 1), making
            #   the distance to the middle (1, 0), then the "correct" thing happened.
            #   The (1, 2) key is a tiny bit closer to the (2, 1) key, and a tuple
            #   comparison of the distance will show it as a smaller distance.
            yield final_high_nibble - final_low_nibble

    @classmethod
    def _new_trie_fog(cls, unexplored_prefixes: SortedSet) -> 'HexaryTrieFog':
        """
        Convert a set of unexplored prefixes to a proper HexaryTrieFog object.

        The given set is used as-is (not copied), so callers must pass a set they no
        longer mutate.
        """
        copy = cls()
        copy._unexplored_prefixes = unexplored_prefixes
        return copy

    def serialize(self) -> bytes:
        """
        Encode this fog as bytes of the form b"HexaryTrieFog:[...]".

        The list is the repr of each unexplored prefix encoded with encode_nibbles().
        deserialize() reverses this.
        """
        # encode nibbles to a bytes value, to compress this down a bit
        prefixes = [
            encode_nibbles(nibbles)
            for nibbles in self._unexplored_prefixes
        ]
        return f"HexaryTrieFog:{prefixes!r}".encode()

    @classmethod
    def deserialize(cls, encoded: bytes) -> 'HexaryTrieFog':
        """
        Rebuild a HexaryTrieFog from the output of serialize().

        The list is parsed with ast.literal_eval, which only evaluates Python literals.

        :raises ValueError: if the input does not start with b"HexaryTrieFog:"
        """
        serial_prefix = b'HexaryTrieFog:'
        if not encoded.startswith(serial_prefix):
            raise ValueError(f"Cannot deserialize this into HexaryTrieFog object: {encoded!r}")
        else:
            encoded_list = encoded[len(serial_prefix):]
            prefix_list = ast.literal_eval(encoded_list.decode())
            deserialized_prefixes = SortedSet(
                # decode nibbles from compressed bytes value, and validate each value in range(16)
                Nibbles(decode_nibbles(prefix))
                for prefix in prefix_list
            )
            return cls._new_trie_fog(deserialized_prefixes)

    def __eq__(self, other: Any) -> bool:
        # Fogs are equal exactly when they have the same unexplored prefixes
        if not isinstance(other, HexaryTrieFog):
            return False
        else:
            return self._unexplored_prefixes == other._unexplored_prefixes
コード例 #6
0
class SchedSleep(BatsimScheduler):
    """
    FCFS Batsim scheduler that also manages the power state of the machines.

    Jobs are started in submission order on idle machines. An idle machine is switched off
    (``PState.Sleep``) immediately when more than ``max_Idle`` of the platform is idle or once
    the workload has ended. Otherwise it is switched off after ``sleep_wait`` time units.
    Sleeping machines are switched back on (``PState.ComputeFast``) in two cases: when the job
    at the head of the queue needs them, and, while the workload is running and ``boot_Idle``
    is set, to keep about ``max_Idle`` of the platform idle.
    """

    def onAfterBatsimInit(self):
        self.nb_completed_jobs = 0

        self.jobs_completed = []
        self.jobs_waiting = []

        self.sched_delay = 0.0

        # Delay (simulation time units) before an idle machine is switched off;
        # 0 = infinite, i.e. no timer-based switch-off
        self.sleep_wait = 0
        # Maximum fraction of the machines that may stay idle
        self.max_Idle = 0.25
        # Switch machines back on when fewer than that fraction are idle
        self.boot_Idle = True
        # Timestamps already passed to wake_me_up_at, so that two wake-up calls are never
        # requested for the same timestamp
        self.requestCall = SortedSet()
        # Set once Batsim reports that no more jobs will be submitted
        self.end_Workload = False
        # machine_wait[i] is the (rounded) timestamp at which machine i must be switched off,
        # or -1 when no switch-off is scheduled for it
        self.machine_wait = [-1] * self.bs.nb_resources

        self.open_jobs = []

        self.computing_machines = SortedSet()
        self.idle_machines = SortedSet(range(self.bs.nb_resources))
        self.sleeping_machines = SortedSet()
        self.switching_ON_machines = SortedSet()
        self.switching_OFF_machines = SortedSet()

        self.machines_states = {
            int(i): State.Idle.value
            for i in range(self.bs.nb_resources)
        }
        print("machines_states", self.machines_states)

        print("machines_waiter", self.machine_wait)

    def scheduleJobs(self):
        """
        Start every waiting job that fits, in submission order, then adjust power states.

        Scheduling stops at the first job that does not fit on the idle machines. If switching
        machines on would make room for that job, enough sleeping machines are switched on.
        Otherwise SleepMachineControl decides what to do with the idle machines.
        SleepMachineControl is also applied when no job is left waiting. All decisions are sent
        to Batsim at the end.
        """
        scheduled_jobs = []
        pstates_to_change = []

        while self.open_jobs:
            job = self.open_jobs[0]
            nb_res_req = job.requested_resources

            if nb_res_req > self.bs.nb_resources:  # Job too big -> rejection
                sys.exit("Rejection unimplemented")

            elif nb_res_req <= len(self.idle_machines):  # Job fits now -> allocation
                res = ProcSet(*self.idle_machines[:nb_res_req])
                job.allocation = res
                scheduled_jobs.append(job)
                for r in res:  # Machines' states update
                    # The machine gets work, so cancel any pending switch-off
                    self.machine_wait[r] = -1
                    self.idle_machines.remove(r)
                    self.computing_machines.add(r)
                    self.machines_states[r] = State.Computing.value
                self.open_jobs.remove(job)

            else:  # Job can fit on the platform, but not now: stop scheduling after this
                nb_not_computing_machines = (self.bs.nb_resources -
                                             len(self.computing_machines))
                if nb_res_req <= nb_not_computing_machines:
                    # The job could fit if more machines were switched ON
                    nb_res_to_switch_ON = (nb_res_req -
                                           len(self.idle_machines) -
                                           len(self.switching_ON_machines))
                    pstates_to_change.extend(
                        self._switch_on_machines(nb_res_to_switch_ON))
                else:
                    # The job cannot fit now because of other jobs
                    pstates_to_change.extend(self.SleepMachineControl())
                break

        # If there is nothing left to schedule, let SleepMachineControl handle idle machines
        if not self.open_jobs:
            pstates_to_change.extend(self.SleepMachineControl())

        # update time
        self.bs.consume_time(self.sched_delay)

        # Make sure Batsim wakes us up for the next scheduled switch-off
        self._request_next_wakeup()

        # send to uds
        self.bs.execute_jobs(scheduled_jobs)
        self._apply_pstates(pstates_to_change)

    def SleepMachineControl(self):
        """
        Schedule or perform switch-offs of idle machines, and switch machines on if needed.

        Each idle machine without a scheduled switch-off gets one. It is immediate if more than
        ``max_Idle`` of the platform is idle or the workload has ended. Otherwise it is set
        ``sleep_wait`` time units ahead, but only when ``sleep_wait`` is non-zero. Every idle
        machine whose switch-off time has been reached starts switching OFF.

        Then, while the workload is running and ``boot_Idle`` is set, sleeping machines are
        switched on until idle plus switching-on machines reach about ``max_Idle`` of the
        platform.

        Returns the list of ``(pstate, (r, r))`` requests to send to Batsim.
        """
        pstates_to_change = []
        now = round(self.bs.time())

        # Idle machine count, decremented for each machine marked for immediate switch-off
        nb_idle_machine = len(self.idle_machines)
        # Iterate over a copy: machines are removed from idle_machines inside the loop
        for r in self.idle_machines.copy():
            # No switch-off scheduled for this machine yet
            if self.machine_wait[r] < 0:
                # Too many idle machines, or end of the workload: switch off right away
                if (nb_idle_machine > self.bs.nb_resources * self.max_Idle
                        or self.end_Workload):
                    # ``now`` (not ``now - 1``) keeps the value >= 0, so it can never be
                    # mistaken for the -1 "unscheduled" marker
                    self.machine_wait[r] = now
                    nb_idle_machine -= 1
                # Otherwise, if sleep_wait is non-zero, switch off after sleep_wait
                elif self.sleep_wait != 0:
                    self.machine_wait[r] = now + self.sleep_wait
            # Any value >= 0 is a scheduled switch-off time; switch off once it is reached
            if 0 <= self.machine_wait[r] <= now:
                self.idle_machines.remove(r)
                self.machine_wait[r] = -1
                self.switching_OFF_machines.add(r)
                self.machines_states[r] = State.SwitchingOFF.value
                pstates_to_change.append((PState.Sleep.value, (r, r)))

        # While the workload is running and boot_Idle is set, switch sleeping machines on
        # so that idle + switching-on machines reach max_Idle of the platform
        if not self.end_Workload and self.boot_Idle:
            nb_need_switch_on = round(self.bs.nb_resources * self.max_Idle -
                                      len(self.idle_machines) -
                                      len(self.switching_ON_machines))
            if nb_need_switch_on > 0:
                pstates_to_change.extend(
                    self._switch_on_machines(nb_need_switch_on))

        return pstates_to_change

    def _switch_on_machines(self, count):
        """
        Start switching ON up to ``count`` machines taken from the front of ``sleeping_machines``.

        Returns the ``(pstate, (r, r))`` requests to send to Batsim. The list is empty when
        ``count`` <= 0 or no machine is sleeping.
        """
        pstates_to_change = []
        nb_switch_ON = min(count, len(self.sleeping_machines))
        if nb_switch_ON > 0:  # also guards against a negative slice bound
            res = self.sleeping_machines[:nb_switch_ON]
            for r in res:  # Machines' states update + pstate change request
                self.sleeping_machines.remove(r)
                self.switching_ON_machines.add(r)
                self.machines_states[r] = State.SwitchingON.value
                pstates_to_change.append((PState.ComputeFast.value, (r, r)))
        return pstates_to_change

    def _request_next_wakeup(self):
        """Ask Batsim to call back at the earliest scheduled switch-off, unless already requested."""
        pending = [t for t in self.machine_wait if t >= 0]
        if not pending:
            return
        next_sleep = min(pending)
        if next_sleep not in self.requestCall:
            self.bs.wake_me_up_at(next_sleep)
            self.requestCall.add(next_sleep)

    def _apply_pstates(self, pstates_to_change):
        """Send each ``(pstate, (r, r))`` request to Batsim."""
        for (val, (r1, _r2)) in pstates_to_change:
            self.bs.set_resource_state(ProcSet(r1), val)

    def onNoMoreJobsInWorkloads(self):
        """Mark the workload as finished and switch OFF every idle machine right away."""
        pstates_to_change = []
        self.end_Workload = True

        # Iterate over a copy: idle_machines must not be modified while it is being iterated
        for r in self.idle_machines.copy():
            self.idle_machines.remove(r)
            self.machine_wait[r] = -1
            self.switching_OFF_machines.add(r)
            self.machines_states[r] = State.SwitchingOFF.value
            pstates_to_change.append((PState.Sleep.value, (r, r)))

        self._apply_pstates(pstates_to_change)

    def onRequestedCall(self):
        """Handle a wake-up call: apply due switch-offs and request the next wake-up."""
        pstates_to_change = self.SleepMachineControl()
        self._request_next_wakeup()
        self._apply_pstates(pstates_to_change)

    def onJobSubmission(self, job):
        if job.requested_resources > self.bs.nb_compute_resources:
            self.bs.reject_jobs(
                [job])  # This job requests more resources than the machine has
        else:
            self.open_jobs.append(job)
            self.scheduleJobs()

    def onJobCompletion(self, job):
        # The job's machines become idle again
        for res in job.allocation:
            self.idle_machines.add(res)
            self.computing_machines.remove(res)
            self.machines_states[res] = State.Idle.value
        self.scheduleJobs()

    def onMachinePStateChanged(self, machines, new_pstate):
        """
        Complete a pending switch ON/OFF of ``machines[0]``, then reschedule.

        Exits the simulation on any transition that was not requested by this scheduler.
        """
        machine = machines[0]
        # Normalize once: every comparison must use the integer pstate
        pstate = int(new_pstate)
        compute_pstates = (PState.ComputeFast.value, PState.ComputeMedium.value,
                           PState.ComputeSlow.value)
        if pstate in compute_pstates:  # switched to a compute pstate
            if self.machines_states[machine] == State.SwitchingON.value:
                self.switching_ON_machines.remove(machine)
                self.idle_machines.add(machine)
                self.machines_states[machine] = State.Idle.value
            else:
                sys.exit(
                    "Unhandled case: a machine switched to a compute pstate but was not switching ON"
                )
        elif pstate == PState.Sleep.value:
            if self.machines_states[machine] == State.SwitchingOFF.value:
                self.switching_OFF_machines.remove(machine)
                self.sleeping_machines.add(machine)
                self.machines_states[machine] = State.Sleeping.value
            else:
                sys.exit(
                    "Unhandled case: a machine switched to a sleep pstate but was not switching OFF"
                )
        else:
            sys.exit("Switched to an unhandled pstate: " + str(new_pstate))

        self.scheduleJobs()
コード例 #7
0
ファイル: mv_list_page.py プロジェクト: fengjixuchui/angr
class MVListPage(
    MemoryObjectSetMixin,
    PageBase,
):
    """
    MVListPage allows storing multiple values at the same location, thus allowing weak updates.

    Each store() may take a value or multiple values, and a "weak" parameter to specify if this store is a weak update
    or not.
    Each load() returns an iterator of all values stored at that location.
    """
    def __init__(self, memory=None, content=None, sinkhole=None, mo_cmp=None, **kwargs):
        """
        Create a page.

        :param memory:   When ``content`` is None and this is given, the page starts as
                         ``memory.page_size`` empty (None) slots.
        :param content:  Initial per-byte content: a list whose slots are None or a set of
                         memory objects.
        :param sinkhole: Memory object used for bytes whose slot is None.
        :param mo_cmp:   Optional comparator, stored as ``_mo_cmp``.
        """
        super().__init__(**kwargs)

        if content is None and memory is not None:
            content = [None] * memory.page_size
        self.content: List[Optional[Set[_MOTYPE]]] = content
        # Page offsets that currently hold an explicitly stored value
        self.stored_offset = SortedSet()
        self._mo_cmp: Optional[Callable] = mo_cmp
        self.sinkhole: Optional[_MOTYPE] = sinkhole

    def copy(self, memo) -> 'MVListPage':
        """
        Return a copy of this page built on top of ``super().copy(memo)``.

        The content list and the stored-offset set are duplicated. The per-byte sets inside
        the content list are shared with this page, and the sinkhole and comparator are
        carried over by reference.
        """
        duplicate = super().copy(memo)
        duplicate.content = list(self.content)
        duplicate.stored_offset = self.stored_offset.copy()
        duplicate.sinkhole = self.sinkhole
        duplicate._mo_cmp = self._mo_cmp
        return duplicate

    def load(self, addr, size=None, endness=None, page_addr=None, memory=None, cooperate=False,
             **kwargs) -> List[Tuple[int,_MOTYPE]]:
        """
        Load ``size`` bytes starting at page offset ``addr``.

        Adjacent bytes holding an equal set of memory objects are collapsed into one
        ``(global address, objects)`` entry. A byte with no stored value reads as ``{sinkhole}``
        when a sinkhole exists. A run of bytes with neither a stored value nor a sinkhole is
        materialized by ``_fill``, which also writes the new object back into the page.

        Unless ``cooperate`` is set, the entries are passed through ``_force_load_cooperation``
        and its result is returned.
        """
        collected = []
        previous = ...  # Ellipsis: a sentinel that never equals a set or None

        for offset in range(addr, addr + size):
            current = self.content[offset]
            if current is None and self.sinkhole is not None:
                current = {self.sinkhole}
            if current == previous:
                continue
            if previous is None:
                # The preceding run had no value at all: materialize it up to this offset
                self._fill(collected, offset, page_addr, endness, memory, **kwargs)
            collected.append((offset + page_addr, current))
            previous = current

        if previous is None:
            # The loaded range ended inside a value-less run
            self._fill(collected, addr + size, page_addr, endness, memory, **kwargs)

        if cooperate:
            return collected
        return self._force_load_cooperation(collected, size, endness, memory=memory, **kwargs)

    def _fill(self, result, addr, page_addr, endness, memory, **kwargs):
        """
        Materialize the value-less run that ends ``result`` (helper for ``load``).

        ``result[-1]`` is a ``(global start address, None)`` entry whose run ends just before
        page offset ``addr``. This helper:

        - creates a default value covering that run and wraps it in a SimMemoryObject;
        - stores the object in ``self.content`` for every byte of the run and marks those
          offsets as stored;
        - replaces ``result[-1]`` with ``(start, object)``.
        """
        start = result[-1][0]
        span = addr + page_addr - start
        value_ast = self._default_value(start, span, name='%s_%x' % (memory.id, start),
                                        key=(self.category, start), memory=memory, **kwargs)
        # NOTE(review): memory.id above is read before this None check, so the fallback to 8
        #   is only reachable if memory is not None — confirm whether memory can be None here.
        width = memory.state.arch.byte_width if memory is not None else 8
        memory_object = SimMemoryObject(value_ast, start, endness=endness, byte_width=width)
        for offset in range(start - page_addr, addr):
            self.content[offset] = {memory_object}
            self.stored_offset.add(offset)
        result[-1] = (start, memory_object)

    def store(self, addr, data, size=None, endness=None, memory=None, cooperate=False, weak=False, **kwargs):
        """
        Store the memory objects ``data`` over ``size`` bytes starting at page offset ``addr``.

        A strong store replaces whatever each byte held. A weak store (``weak=True``) adds
        ``data`` to the objects already stored at each byte.

        A strong store of a single object covering the whole page is collapsed into the
        sinkhole. ``load`` wraps the sinkhole as ``{sinkhole}``, so it must be one memory object,
        not a set. Multi-object or weak full-page stores go through the per-byte path instead,
        so that no value is lost.

        Per-byte sets are never updated in place. ``copy()`` shares them between pages, so an
        in-place union would leak into every copy of this page.
        """
        if not cooperate:
            data = self._force_store_cooperation(addr, data, size, endness, memory=memory, **kwargs)

        data: Set[_MOTYPE]

        if size == len(self.content) and addr == 0 and len(data) == 1 and not weak:
            self.sinkhole = next(iter(data))
            self.content = [None] * len(self.content)
            self.stored_offset = SortedSet()
            return

        for subaddr in range(addr, addr + size):
            existing = self.content[subaddr]
            if weak and existing is not None:
                # Build a new set rather than `|=`: `existing` may be shared with a copied page
                self.content[subaddr] = existing | data
            else:
                # NOTE(review): a weak store over a sinkhole-backed byte drops the sinkhole value
                self.content[subaddr] = set(data)
            self.stored_offset.add(subaddr)

    def merge(self, others: List['MVListPage'], merge_conditions, common_ancestor=None, page_addr: int = None,
              memory=None, changed_offsets: Optional[Set[int]]=None):
        """
        Merge the contents of ``others`` into this page, in place.

        :param others:           The other pages to merge into this one.
        :param merge_conditions: One condition per page, in the order ``[self] + others``. If None,
                                 a list of None (one per page) is used.
        :param common_ancestor:  Not used by this implementation.
        :param page_addr:        Base address of this page; page offsets plus page_addr are
                                 global addresses.
        :param memory:           Forwarded to ``_merge_values`` and ``_default_value``.
        :param changed_offsets:  Page offsets to consider. If None, the union of
                                 ``self.changed_bytes(other, page_addr)`` over all ``others`` is used.
        :return:                 The set of page offsets at which a merged value was stored.
        """

        if changed_offsets is None:
            changed_offsets = set()
            for other in others:
                changed_offsets |= self.changed_bytes(other, page_addr)

        all_pages: List['MVListPage'] = [self] + others
        if merge_conditions is None:
            merge_conditions = [None] * len(all_pages)

        # Offsets below merged_to were already covered by a previous merge and are skipped
        merged_to = None
        # NOTE(review): merged_objects is never added to in this method, so the
        #   `mos - merged_objects` test below is equivalent to `mos`.
        merged_objects = set()
        # Offsets at which a merged value was stored (the return value)
        merged_offsets = set()
        for b in sorted(changed_offsets):
            if merged_to is not None and not b >= merged_to:
                l.info("merged_to = %d ... already merged byte 0x%x", merged_to, b)
                continue
            l.debug("... on byte 0x%x", b)

            memory_objects = []
            unconstrained_in = []

            # first get a list of all memory objects at that location, and
            # all memories that don't have those bytes
            # (`_contains` presumably reports whether a page holds a value at this offset — confirm)
            for sm, fv in zip(all_pages, merge_conditions):
                if sm._contains(b, page_addr):
                    l.info("... present in %s", fv)
                    for mo in sm.content[b]:
                        memory_objects.append((mo, fv))
                else:
                    l.info("... not present in %s", fv)
                    unconstrained_in.append((sm, fv))

            mos = set(mo for mo, _ in memory_objects)
            mo_bases = set(mo.base for mo, _ in memory_objects)
            mo_lengths = set(mo.length for mo, _ in memory_objects)
            endnesses = set(mo.endness for mo in mos)

            if not unconstrained_in and not (mos - merged_objects):
                continue

            # first, optimize the case where we are dealing with the same-sized memory objects
            if len(mo_bases) == 1 and len(mo_lengths) == 1 and not unconstrained_in and len(endnesses) == 1:
                the_endness = next(iter(endnesses))
                to_merge = [(mo.object, fv) for mo, fv in memory_objects]

                # Update `merged_to`
                # (size = bytes from b to the end of the shared memory object)
                mo_base = list(mo_bases)[0]
                mo_length = memory_objects[0][0].length
                size = mo_length - (page_addr + b - mo_base)
                merged_to = b + size

                merged_val = self._merge_values(to_merge, mo_length, memory=memory)
                if merged_val is None:
                    # merge_values() determines that we should not attempt to merge this value
                    continue

                # do the replacement
                # TODO: Implement in-place replacement instead of calling store()
                # new_object = self._replace_memory_object(our_mo, merged_val, page_addr, memory.page_size)

                # The first merged value is a strong store; the rest are added as weak stores
                first_value = True
                for v in merged_val:
                    self.store(b,
                               { SimMemoryObject(v, mo_base, endness=the_endness) },
                               size=size,
                               cooperate=True,
                               weak=not first_value,
                               )
                    first_value = False

                merged_offsets.add(b)

            else:
                # get the size that we can merge easily. This is the minimum of
                # the size of all memory objects and unallocated spaces.
                # NOTE(review): if no page holds a memory object at b (e.g. a caller-supplied
                #   offset that no page contains), memory_objects is empty and min() raises ValueError.
                min_size = min([mo.length - (b + page_addr - mo.base) for mo, _ in memory_objects])
                for um, _ in unconstrained_in:
                    for i in range(0, min_size):
                        if um._contains(b + i, page_addr):
                            min_size = i
                            break
                merged_to = b + min_size
                l.info("... determined minimum size of %d", min_size)

                # Now, we have the minimum size. We'll extract/create expressions of that
                # size and merge them
                extracted = [(mo.bytes_at(page_addr + b, min_size), fv) for mo, fv in
                             memory_objects] if min_size != 0 else []
                created = [
                    (self._default_value(None, min_size, name="merge_uc_%s_%x" % (uc.id, b), memory=memory),
                     fv) for
                    uc, fv in unconstrained_in
                ]
                to_merge = extracted + created

                merged_val = self._merge_values(to_merge, min_size, memory=memory)
                if merged_val is None:
                    continue

                # The first merged value is a strong store; the rest are added as weak stores
                first_value = True
                for v in merged_val:
                    self.store(b,
                               { SimMemoryObject(v, page_addr + b, endness='Iend_BE') },
                               size=min_size,
                               endness='Iend_BE',
                               cooperate=True,
                               weak=not first_value,
                               )  # do not convert endianness again
                    first_value = False
                merged_offsets.add(b)

        # Record the merged offsets as stored
        self.stored_offset |= merged_offsets
        return merged_offsets

    def changed_bytes(self, other: 'MVListPage', page_addr: int = None):
        """
        Return the offsets within this page whose contents differ from ``other``.

        Candidate offsets are gathered per page: a page without a sinkhole contributes
        its ``stored_offset`` set, a page with a sinkhole contributes every offset it
        currently contains.  A candidate counts as changed when exactly one page contains
        it, or when both contain it and their memory objects are judged different --
        by ``self._mo_cmp`` when it is set and returns a non-None answer, otherwise by
        comparing the sets of single bytes extracted at that address.

        Side effect: for an offset that both pages contain, a page that only covers it
        through its sinkhole gets a one-byte ``SimMemoryObject`` written into
        ``content`` at that offset before comparison.

        :param other:     Page to compare against.
        :param page_addr: Base address of the page; offsets are added to it to form addresses.
        :return:          Set of differing offsets.
        """
        candidates: Set[int] = set()
        for page in (self, other):
            if page.sinkhole is None:
                candidates |= page.stored_offset
            else:
                candidates.update(i for i in range(len(page.content)) if page._contains(i, page_addr))

        byte_width = 8  # TODO: Introduce self.state if we want to use self.state.arch.byte_width
        differences: Set[int] = set()
        for c in candidates:
            s_contains = self._contains(c, page_addr)
            o_contains = other._contains(c, page_addr)
            if s_contains != o_contains:
                # present in exactly one of the two pages
                differences.add(c)
                continue
            if not s_contains:
                # Present in neither page (e.g. a stored offset whose slot was cleared).
                # Nothing to compare -- and materializing from the sinkhole here would read
                # an address the sinkhole does not cover (or crash when there is none).
                continue

            # Both pages contain the byte; a None slot means it is covered by the sinkhole,
            # so materialize it to compare slot contents directly.
            if self.content[c] is None:
                self.content[c] = { SimMemoryObject(self.sinkhole.bytes_at(page_addr + c, 1), page_addr + c,
                                                  byte_width=byte_width, endness='Iend_BE') }
            if other.content[c] is None:
                other.content[c] = { SimMemoryObject(other.sinkhole.bytes_at(page_addr + c, 1), page_addr + c,
                                                   byte_width=byte_width, endness='Iend_BE') }
            if self.content[c] == other.content[c]:
                continue

            same = None
            if self._mo_cmp is not None:
                same = self._mo_cmp(self.content[c], other.content[c], page_addr + c, 1)
            if same is None:
                # Fall back to comparing the actual byte values
                self_bytes = { mo.bytes_at(page_addr + c, 1) for mo in self.content[c] }
                other_bytes = { mo.bytes_at(page_addr + c, 1) for mo in other.content[c] }
                same = self_bytes == other_bytes
            if same is False:
                differences.add(c)

        return differences

    def _contains(self, off: int, page_addr: int):
        if off >= len(self.content):
            return False
        if self.content[off] is not None:
            return True
        if self.sinkhole is None:
            return False
        return self.sinkhole.includes(page_addr + off)

    def _replace_mo(self, old_mo: SimMemoryObject, new_mo: SimMemoryObject, page_addr: int,
                    page_size: int) -> SimMemoryObject:
        """
        Substitute ``new_mo`` for ``old_mo`` wherever this page references it.

        If ``old_mo`` is the sinkhole, the sinkhole is replaced.  Otherwise, within the part
        of the page that ``old_mo`` spans (see ``_resolve_range``), every slot referencing
        ``old_mo`` is updated.  Slots in this page store *sets* of memory objects, so
        ``old_mo`` is swapped out of the set (matched by identity) and any other values in
        the slot are kept.  A slot holding ``old_mo`` itself is also handled.

        :return: ``new_mo``.
        """
        if self.sinkhole is old_mo:
            self.sinkhole = new_mo
            return new_mo

        start, end = self._resolve_range(old_mo, page_addr, page_size)
        for addr in range(start, end):
            off = addr - page_addr
            mos = self.content[off]
            if mos is old_mo:
                self.content[off] = { new_mo }
            elif isinstance(mos, (set, frozenset)) and any(mo is old_mo for mo in mos):
                # Comparing the whole set against old_mo (as before) could never match;
                # replace just the stale element in a fresh set.
                self.content[off] = { new_mo if mo is old_mo else mo for mo in mos }
        return new_mo

    @staticmethod
    def _resolve_range(mo: SimMemoryObject, page_addr: int, page_size) -> Tuple[int, int]:
        """
        Clip the address span of ``mo`` to the page ``[page_addr, page_addr + page_size)``.

        :return: ``(start, end)`` absolute addresses with ``end`` exclusive.  When ``mo`` does
                 not overlap the page, ``end <= start`` and a warning is logged.
        """
        start, end = max(mo.base, page_addr), min(mo.last_addr + 1, page_addr + page_size)
        if start >= end:
            l.warning("Nothing left of the memory object to store in SimPage.")
        return start, end

    def _get_objects(self, start: int, page_addr: int) -> Optional[List[SimMemoryObject]]:
        """
        Return the memory objects in slot ``start`` that include address ``page_addr + start``.

        :return: A non-empty list of those objects, or ``None`` when the slot is empty or
                 none of its objects include that address.
        """
        slot = self.content[start]
        if slot is None:
            return None
        addr = start + page_addr
        covering = [mo for mo in slot if mo.includes(addr)]
        return covering or None
コード例 #8
0
ファイル: history.py プロジェクト: ARoefer/kineverse
class History(object):
    """
    Stamp-ordered record of operation chunks and of the dependencies between them.

    * ``chunk_history`` -- ``Timeline`` of every chunk, looked up by ``stamp``.
    * ``modification_history`` -- dict mapping each path to the ``Timeline`` of chunks
      that modify it.
    * ``dirty_chunks`` -- ``SortedSet`` of chunks flagged for re-evaluation.

    Dependency edges live on the chunks themselves: a chunk is added to the ``dependents``
    of the latest earlier modifier of every path in its ``dependencies``.
    """

    def __init__(self, history=None, modification_history=None):
        self.chunk_history = Timeline() if history is None else Timeline(
            history)
        if modification_history is None:
            # Dict path -> Timeline of the chunks modifying that path
            self.modification_history = {}
            for c in self.chunk_history:
                # Link dependencies *before* registering c's own modifications (the same
                # order insert_chunk uses).  Otherwise a chunk that reads and writes the
                # same path finds itself as the latest modifier and becomes its own dependent.
                for p in c.dependencies:
                    if p not in self.modification_history:
                        raise Exception(
                            'Illegal sequence of operations was supplied! Referenced dependency {} does not exist at time {}'
                            .format(p, c.stamp))
                    self.modification_history[p][-1].dependents.add(c)
                for p in c.modifications:
                    if p not in self.modification_history:
                        self.modification_history[p] = Timeline()
                    self.modification_history[p].add(c)
        else:
            self.modification_history = modification_history
        self.dirty_chunks = SortedSet()

    def __iter__(self):
        """Iterate over all chunks in ``chunk_history`` order."""
        return iter(self.chunk_history)

    def __len__(self):
        """
        Return the number of paths that have a modification history.

        NOTE(review): this counts paths while ``__iter__`` yields chunks; kept for
        compatibility -- confirm it is intended.
        """
        return len(self.modification_history)

    def get_time_stamp(self, before=None, after=None):
        """
        Compute a stamp at which a new chunk can be placed.

        :param before: Chunk or stamp.  The result is midway between the chunk returned by
                       ``get_ceil`` for it and that chunk's predecessor, or one less than the
                       chunk's stamp if it is the first chunk.
        :param after:  Chunk or stamp.  The result is midway between the chunk returned by
                       ``get_floor`` for it and that chunk's successor, or one more than the
                       chunk's stamp if it is the last chunk.
        :return: With neither argument: one past the last stamp, or ``1`` if empty.
        """
        if before is not None:
            stamp = before.stamp if isinstance(before, Chunk) else before
            pos, succ = self.chunk_history.get_ceil(stamp)
            if pos > 0:
                return 0.5 * (succ.stamp + self.chunk_history[pos - 1].stamp)
            return succ.stamp - 1
        elif after is not None:
            stamp = after.stamp if isinstance(after, Chunk) else after
            pos, pred = self.chunk_history.get_floor(stamp)
            if pos < len(self.chunk_history) - 1:
                return 0.5 * (pred.stamp + self.chunk_history[pos + 1].stamp)
            return pred.stamp + 1
        return self.chunk_history[-1].stamp + 1 if len(
            self.chunk_history) > 0 else 1

    @profile
    def _insert_modification(self, chunk, path):
        """
        Register ``chunk`` as a modifier of ``path``.

        Dependents of the previous modifier of ``path`` that are stamped after ``chunk``
        and depend on something ``chunk`` modifies are moved onto ``chunk`` and flagged
        dirty; they are dropped from the predecessor only when none of their remaining
        dependencies are modified by it.
        """
        if path not in self.modification_history:
            self.modification_history[path] = Timeline()
        _, pred = self.modification_history[path].get_floor(chunk.stamp)
        if pred is not None:
            to_remove = set()
            for d in pred.dependents:
                # Fetch all dependents from predecessor which are going to depend on the new chunk
                # Save them as dependents and mark them as dirty
                if d.stamp > chunk.stamp:
                    dep_overlap_diff = d.dependencies.difference(
                        chunk.modifications)
                    # Is there at least one element overlap
                    if len(dep_overlap_diff) < len(d.dependencies):
                        chunk.dependents.add(d)
                        self.dirty_chunks.add(d)
                        # If there is no remaining overlap with pred anymore, remove d
                        if len(dep_overlap_diff.difference(
                                pred.modifications)) == len(dep_overlap_diff):
                            to_remove.add(d)
            pred.dependents -= to_remove
        self.modification_history[path].add(chunk)

    @profile
    def insert_chunk(self, chunk):
        """
        Insert ``chunk`` into the history.

        ``chunk`` is added to the dependents of the latest modifier (at or before its stamp)
        of each path it depends on, then registered for each path it modifies.

        :raises Exception: if a dependency has no history, or none at or before the stamp.
        """
        for p in chunk.dependencies:
            if p not in self.modification_history:
                raise Exception(
                    'Chunk depends on attribute without history!\n Operation "{}" at {}\n Attribute: {}\n'
                    .format(chunk.operation.name, chunk.stamp, p))
            _, pred = self.modification_history[p].get_floor(chunk.stamp)
            if pred is None:
                raise Exception(
                    'Chunk at time {} executing "{}" depends on attributes with empty history! Attributes:\n  {}'
                    .format(
                        chunk.stamp, chunk.operation.name, '\n  '.join([
                            str(p) for p in chunk.dependencies
                            if p not in self.modification_history
                            or self.modification_history[p].get_floor(
                                chunk.stamp)[1] is None
                        ])))
            pred.dependents.add(chunk)

        for p in chunk.modifications:
            self._insert_modification(chunk, p)

        self.chunk_history.add(chunk)

    @profile
    def remove_chunk(self, chunk):
        """
        Remove ``chunk`` from the history.

        Its dependents are handed to the previous modifier of each path they read, it is
        unlinked from the modifiers it depended on, and its dependents are flagged dirty.

        :raises Exception: if ``chunk`` is the first modifier of a path that one of its
            dependents reads (removing it would leave that dependency dangling).
        """
        for p in chunk.modifications:
            if self.modification_history[p][0] == chunk and any(
                    p in c.dependencies for c in chunk.dependents):
                raise Exception(
                    'Can not remove chunk at timestamp {} because it is the founding chunk in the history of {} and would create dangling dependencies.'
                    .format(chunk.stamp, p))

        for p in chunk.modifications:
            self.modification_history[p].discard(chunk)
            _, pred = self.modification_history[p].get_floor(chunk.stamp)
            # Copy dependents that depend on this variable to predecessor
            if pred is not None:
                pred.dependents.update(
                    {d
                     for d in chunk.dependents if p in d.dependencies})

        for p in chunk.dependencies:
            pos, pred = self.modification_history[p].get_floor(chunk.stamp)
            if pred is None:
                raise Exception(
                    'Chunk depends on attribute with empty history!')
            # The chunk may modify a path it also depends on; should get_floor still yield
            # the chunk itself, step back to the real predecessor before unlinking.
            # (Defensive: the loop above already discarded it from those timelines.)
            if pred == chunk:
                pos -= 1
                pred = self.modification_history[p][pos]
            pred.dependents.discard(chunk)

        self.chunk_history.remove(chunk)
        self.dirty_chunks.update(chunk.dependents)

    @profile
    def replace_chunk(self, c_old, c_new):
        """
        Replace ``c_old`` with ``c_new`` at the same stamp.

        ``c_new`` must modify at least every path ``c_old`` modifies.  It inherits (a copy
        of) ``c_old``'s dependents, which are flagged dirty, and dependency links are moved
        from ``c_old`` to ``c_new``.

        :raises Exception: if the stamps differ, if ``c_new`` does not cover all of
            ``c_old``'s modifications, or if ``c_new`` depends on a path outside the shared
            modifications that has no modifier at or before its stamp.
        """
        if c_old.stamp != c_new.stamp:
            raise Exception(
                'Can only replace chunk if stamps match. Stamps:\n Old: {:>8.3f}\n New: {:>8.3f}'
                .format(c_old.stamp, c_new.stamp))

        overlap = c_old.modifications.intersection(c_new.modifications)
        if len(overlap) != len(c_old.modifications):
            raise Exception(
                'Chunks can only be replaced by others with at least the same definition coverage. Missing variables:\n {}'
                .format('\n '.join(
                    sorted(str(p) for p in c_old.modifications.difference(
                        c_new.modifications)))))

        new_deps = {
            p: self.modification_history[p].get_floor(c_new.stamp)[1]
            if p in self.modification_history else None
            for p in c_new.dependencies.difference(overlap)
        }
        missing = sorted(str(p) for p, c in new_deps.items() if c is None)
        if missing:
            # Both placeholders are filled (the message used to get only one argument and
            # raised IndexError instead); only the offending paths are listed.
            raise Exception(
                'Replacement chunk at {} tries to depend on variables with insufficient histories. variables:\n {}'
                .format(c_new.stamp, '\n '.join(missing)))

        for p in overlap:
            pos, _ = self.modification_history[p].get_floor(c_old.stamp)
            # If we are already here, we might as well remove old and establish new deps
            # NOTE(review): assumes a predecessor exists (pos > 0) whenever a chunk depends
            # on p; at pos 0 the index pos - 1 would wrap to the last chunk.
            if p in c_old.dependencies:
                self.modification_history[p][pos - 1].dependents.discard(c_old)
            if p in c_new.dependencies:
                self.modification_history[p][pos - 1].dependents.add(c_new)
            self.modification_history[p].remove(c_old)
            self.modification_history[p].add(c_new)

        c_new.dependents = c_old.dependents.copy()
        self.flag_dirty(*c_new.dependents)

        # Remove old, non-modified deps
        for p in c_old.dependencies.difference(overlap):
            self.modification_history[p].get_floor(
                c_old.stamp)[1].dependents.remove(c_old)

        # Insert additional modifications
        for p in c_new.modifications.difference(overlap):
            self._insert_modification(c_new, p)

        for c in new_deps.values():
            c.dependents.add(c_new)

        self.chunk_history.remove(c_old)
        self.chunk_history.add(c_new)

    def get_chunk_by_index(self, idx):
        """Return the chunk at position ``idx`` of ``chunk_history``."""
        return self.chunk_history[idx]

    def get_chunk(self, stamp):
        """Return the chunk stamped exactly ``stamp``, or ``None``."""
        return self.get_chunk_pos(stamp)[0]

    def get_chunk_pos(self, stamp):
        """
        Return ``(chunk, pos)`` for the chunk stamped exactly ``stamp``.

        If the closest chunk at or before ``stamp`` has a different stamp the result is
        ``(None, None)``; if there is no such chunk it is ``(None, pos)`` with ``pos`` as
        reported by ``get_floor``.
        """
        pos, chunk = self.chunk_history.get_floor(stamp)
        return (chunk,
                pos) if chunk is None or chunk.stamp == stamp else (None, None)

    def flag_dirty(self, *chunks):
        """Mark ``chunks`` as dirty."""
        self.dirty_chunks.update(chunks)

    def flag_clean(self, *chunks):
        """Remove ``chunks`` from the dirty set (absent chunks are ignored)."""
        for c in chunks:
            self.dirty_chunks.discard(c)

    def expand_dirty_set(self):
        """Add every transitive dependent of the dirty chunks to the dirty set."""
        active_set = set(self.dirty_chunks)
        while len(active_set) > 0:
            a = active_set.pop()
            u = a.dependents.difference(self.dirty_chunks)
            active_set.update(u)
            self.dirty_chunks.update(u)

    def get_dirty(self):
        """Return a copy of the dirty set."""
        return self.dirty_chunks.copy()

    def get_subhistory(self, time):
        """
        Return a new History restricted to the chunks stamped at or before ``time``.

        The result shares chunk objects with this history.  Paths whose first
        modification comes after ``time`` are omitted.  An empty History is returned when
        this history is empty or begins after ``time``.
        """
        if len(self.chunk_history) > 0 and self.chunk_history[0].stamp <= time:
            # get_floor() returns (position, chunk): keep everything up to that position.
            chunks = self.chunk_history[:self.chunk_history.get_floor(time
                                                                      )[0] + 1]
            mod_history = {
                p: Timeline(h[:h.get_floor(time)[0] + 1])
                for p, h in self.modification_history.items()
                if len(h) > 0 and h[0].stamp <= time
            }
            return History(chunks, mod_history)
        return History()

    def get_history_of(self, *paths):
        """
        Return a Timeline of every chunk that modifies one of ``paths`` plus, transitively,
        the modifiers those chunks depend on.
        """
        out = set()
        remaining = set()
        for p in paths:
            if p in self.modification_history:
                remaining.update(self.modification_history[p])

        while len(remaining) > 0:
            chunk = remaining.pop()
            out.add(chunk)
            for p in chunk.dependencies:
                pos, dep = self.modification_history[p].get_floor(chunk.stamp)
                if dep == chunk:  # Catch if predecessor is chunk itself
                    dep = self.modification_history[p][pos - 1]
                if dep not in out:
                    remaining.add(dep)

        return Timeline(out)

    def str_history_of(self, p):
        """
        Format the modification history of path ``p``, one ``stamp : operation`` per line.

        :raises Exception: if ``p`` has no history.
        """
        if p not in self.modification_history:
            raise Exception('Path {} has no history.'.format(p))
        return '\n'.join([
            '{:>8.3f} : {}'.format(chunk.stamp, str(chunk.operation))
            for chunk in self.modification_history[p]
        ])

    def str_history(self):
        """Format the whole history, one ``stamp : operation`` per line."""
        return '\n'.join([
            '{:>8.3f} : {}'.format(chunk.stamp, str(chunk.operation))
            for chunk in self.chunk_history
        ])

    def __eq__(self, other):
        """Histories are equal when their chunk timelines compare equal."""
        if isinstance(other, History):
            return self.chunk_history == other.chunk_history
        return False
コード例 #9
0
def test_copy():
    """copy() must return an independent SortedSet: growing the copy leaves the original alone."""
    # ``load=`` is the sortedcontainers 1.x constructor argument; 2.x dropped it in
    # favour of ``_reset(load)``.
    temp = SortedSet(range(100), load=7)
    that = temp.copy()
    that.add(1000)
    assert len(that) == 101
    # The original must be unaffected by the add on the copy (was wrongly 101).
    assert len(temp) == 100
コード例 #10
0
class ARG(object):
    '''
    Ancestral Recombination Graph
    '''
    def __init__(self):
        '''Create an empty ARG.'''
        # index -> Node
        self.nodes = {}
        # bintrees AVL trees keyed by node index: roots, recombination parents,
        # coalescence (CA) parents
        self.roots = bintrees.AVLTree()
        self.rec = bintrees.AVLTree()
        self.coal = bintrees.AVLTree()
        # recombination counters and cached total branch length
        self.num_ancestral_recomb = self.num_nonancestral_recomb = 0
        self.branch_length = 0
        # naming: next fresh index, plus recycled indices handed out first by new_name()
        self.nextname = 1
        self.available_names = SortedSet()

    def __iter__(self):
        '''Iterate over the node indices in the ARG.

        ``__iter__`` must return an iterator; returning the list itself made
        ``iter(arg)`` and ``for i in arg`` raise TypeError.  The keys are snapshotted
        so nodes may be added or removed while iterating.
        '''
        return iter(list(self.nodes))

    def __len__(self):
        '''Return how many nodes the ARG holds.'''
        return self.nodes.__len__()

    def __getitem__(self, index):
        '''Look up a node by index (KeyError when absent).'''
        node = self.nodes[index]
        return node

    def __setitem__(self, index, node):
        '''Store ``node`` under ``index`` via ``add``, first syncing ``node.index``.'''
        # add() keys on node.index, so it has to match before registering
        node.index = index
        self.add(node)

    def __contains__(self, index):
        '''True when a node with key ``index`` is in the ARG.'''
        return index in self.nodes.keys()

    def copy(self):
        '''Return a copy of the ARG.

        Every node is duplicated with ``node.copy()`` and its child/parent links are
        re-pointed at the corresponding duplicates.  The index trees and the
        available-name set are copied with their own ``copy()``; counters and
        ``nextname`` are copied by value.
        '''
        clone = ARG()
        # pass 1: duplicate nodes so every index has a target to link to
        for original in self.nodes.values():
            clone.nodes[original.index] = original.copy()
        # pass 2: rewire links onto the duplicates
        lookup = clone.__getitem__
        for original in self.nodes.values():
            twin = lookup(original.index)
            if original.left_child is not None:
                twin.left_child = lookup(original.left_child.index)
                twin.right_child = lookup(original.right_child.index)
            if original.left_parent is not None:
                twin.left_parent = lookup(original.left_parent.index)
                twin.right_parent = lookup(original.right_parent.index)
        clone.roots = self.roots.copy()
        clone.rec = self.rec.copy()
        clone.coal = self.coal.copy()
        clone.num_ancestral_recomb = self.num_ancestral_recomb
        clone.num_nonancestral_recomb = self.num_nonancestral_recomb
        clone.branch_length = self.branch_length
        clone.nextname = self.nextname
        clone.available_names = self.available_names.copy()
        return clone

    def equal(self, other):
        '''Structural equality with ``other``.

        True when both hold the same number of nodes and every node here has a
        counterpart with the same index in ``other`` for which ``node.equal`` holds.
        TODO : complete this'''
        if len(self) != len(other):
            return False
        return all(node.index in other and node.equal(other[node.index])
                   for node in self.nodes.values())

    def leaves(self, node=None):
        """
        Yield leaf nodes, i.e. nodes without a left child.

        With no ``node`` every node of the ARG is checked; otherwise only the nodes
        produced by ``self.preorder(node)``.
        """
        pool = self.nodes.values() if node is None else self.preorder(node)
        for candidate in pool:
            if candidate.left_child is None:
                yield candidate

    def preorder(self, node=None):
        """
        Yield every node reachable child-wards from ``node``, each exactly once.

        Defaults to the root with the largest key in ``self.roots``.  Despite the name,
        nodes come out in first-in-first-out (breadth-first) order; when both child
        slots share an index, that child is enqueued only once.
        """
        if node is None:
            node = self.__getitem__(self.roots.max_key())
        pending = collections.deque([node])
        seen = set()
        while pending:
            current = pending.popleft()
            if current in seen:
                continue
            yield current
            seen.add(current)
            left = current.left_child
            if left is not None:
                pending.append(left)
                if left.index != current.right_child.index:
                    pending.append(current.right_child)

    def postorder(self, node=None):
        """
        Yield nodes so that each parent comes after all of its distinct children.

        Starts from ``self.leaves(node)``; a parent is queued once it has been reached
        from each of its distinct children (one when both child slots share an index,
        otherwise two).  The required count is now computed for each parent separately;
        previously the right parent reused the left parent's count.
        """
        def distinct_children(parent):
            # a parent whose two child slots point at the same index waits for one visit
            return 1 if parent.left_child.index == parent.right_child.index else 2

        visit = collections.defaultdict(int)
        queue = list(self.leaves(node))

        for current in queue:
            yield current
            left_parent = current.left_parent
            if left_parent is None:
                continue
            visit[left_parent] += 1
            if visit[left_parent] == distinct_children(left_parent):
                queue.append(left_parent)
            right_parent = current.right_parent
            if right_parent.index != left_parent.index:
                visit[right_parent] += 1
                if visit[right_parent] == distinct_children(right_parent):
                    queue.append(right_parent)

    def set_roots(self):
        '''Rebuild ``self.roots``: every node without a left parent is a root (key == value == index).'''
        self.roots.clear()
        parentless = [n.index for n in self.nodes.values() if n.left_parent is None]
        for idx in parentless:
            self.roots[idx] = idx

    def get_times(self):
        '''Return the distinct node times as a SortedSet.'''
        return SortedSet(node.time for node in self.nodes.values())

    def get_higher_nodes(self, t):
        ''':return the indices of nodes whose time is >= t (linear scan)
        TODO: a more efficient search option
        '''
        return [index for index, node in self.nodes.items() if node.time >= t]

    #==========================
    # node manipulation
    def alloc_segment(self,
                      left=None,
                      right=None,
                      node=None,
                      samples=None,
                      prev=None,
                      next=None):
        """
        Allocate a new Segment and fill in its fields.

        :param samples: sample tree to attach.  When omitted (or None) a fresh, empty
            ``bintrees.AVLTree`` is created for this segment.  The old default
            ``bintrees.AVLTree()`` was evaluated once, at class definition, so every
            segment allocated without ``samples`` shared (and mutated) the same tree.
        :return: the new segment.
        """
        s = Segment()
        s.left = left
        s.right = right
        s.node = node
        s.samples = bintrees.AVLTree() if samples is None else samples
        s.next = next
        s.prev = prev
        return s

    def alloc_node(self, index=None, time=None, left_child=None, right_child=None):
        """
        Allocate a new Node with the given index, time and children.

        The node starts with no segments, no parents, no breakpoint and an empty
        ``snps`` AVL tree.  It is not registered in ``self.nodes``.
        """
        node = Node(index)
        node.time = time
        node.first_segment = None
        node.left_child, node.right_child = left_child, right_child
        node.left_parent = node.right_parent = None
        node.breakpoint = None
        node.snps = bintrees.AVLTree()
        return node

    def store_node(self, segment, node):
        '''Give ``node`` a copy of the segment chain containing ``segment``, then register it.

        Walks back to the head of ``segment``'s chain and builds a parallel chain of
        new segments (same bounds, copied ``samples``) owned by ``node``; each
        original segment's ``node`` is also pointed at ``node``.  With ``segment``
        None the node gets no segments.  ``node`` is stored under ``node.index``.
        '''
        if segment is None:
            node.first_segment = None
        else:
            head = segment
            while head.prev is not None:
                head = head.prev
            tail = None
            src = head
            while src is not None:
                fresh = self.alloc_segment(src.left, src.right, node, src.samples.copy(),
                                           tail)
                if tail is None:
                    node.first_segment = fresh
                else:
                    tail.next = fresh
                src.node = node
                tail = fresh
                src = src.next
        self.nodes[node.index] = node

    def copy_node_segments(self, node):
        '''
        Build a fresh segment chain mirroring ``node.first_segment``'s chain.

        In CA or recombination events the segments have to be copied before they are
        modified.  The new segments keep the same bounds, get copied ``samples`` and
        belong to ``node``; each original segment's ``node`` is set to ``node`` too.
        ``node.first_segment`` itself is not changed.

        :return: the LAST segment of the new chain (follow ``prev`` to reach its head),
            or None when ``node`` has no segments.
        '''
        src = node.first_segment
        if src is None:
            return None
        assert src.prev is None
        copied = None
        while src is not None:
            copied = self.alloc_segment(src.left, src.right, node, src.samples.copy(),
                                        copied)
            if copied.prev is not None:
                copied.prev.next = copied
            src.node = node
            src = src.next
        return copied

    def get_available_names(self):
        '''Recompute ``self.available_names``.

        The result is the integers lying strictly between the smallest and the largest
        node index that no node uses.  (Indices below the smallest one are not
        included, despite what the old docstring claimed.)  An ARG without nodes
        leaves the set empty instead of raising IndexError.
        '''
        self.available_names = SortedSet()
        names = sorted(self.nodes)
        for lower, upper in zip(names, names[1:]):
            if upper != lower + 1:
                self.available_names.update(range(lower + 1, upper))

    def new_name(self):
        '''Return an unused node index.

        Recycles an entry of ``available_names`` if any (``SortedSet.pop()`` removes
        the largest); otherwise hands out ``nextname`` and advances it.
        '''
        if not self.available_names:
            fresh = self.nextname
            self.nextname = fresh + 1
            return fresh
        return self.available_names.pop()

    def add(self, node):
        '''Register an already-built node under ``node.index`` and return it.'''
        key = node.index
        self.nodes[key] = node
        return node

    def rename(self, oldindex, newindex):
        '''Move the node at ``oldindex`` to ``newindex`` and update its ``index`` (KeyError if absent).'''
        node = self.nodes.pop(oldindex)
        node.index = newindex
        self.nodes[newindex] = node

    def total_branch_length(self):
        '''Total branch length: sum over non-root nodes of segment span times the
        time gap to the node's left parent.'''
        total = 0
        for node in self.nodes.values():
            parent = node.left_parent
            if parent is None:
                continue
            age = parent.time - node.time
            seg = node.first_segment
            while seg is not None:
                total += (seg.right - seg.left) * age
                seg = seg.next
        return total

    #=======================
    #spr related

    def detach(self, node, sib):
        '''
        Detaches a specified coalescence node from the rest of the ARG

        ``node`` must have a single parent (its left and right parent share an index).
        That parent is spliced out: ``sib`` (presumably ``node``'s sibling -- confirm
        with callers) takes over the parent's left/right parents and breakpoint.  When
        the parent has a left parent, both of the parent's parents are told via
        ``update_child`` to point at ``sib`` instead.  ``node`` keeps its link to the
        old parent.
        '''
        # print("Detach()",node.index, "sib", sib.index, "p",node.left_parent.index)
        assert node.left_parent.index == node.right_parent.index
        parent = node.left_parent
        # sib inherits the parent's place in the graph
        sib.left_parent = parent.left_parent
        sib.right_parent = parent.right_parent
        sib.breakpoint = parent.breakpoint
        grandparent = parent.left_parent
        # NOTE(review): the right grandparent is only updated when a left grandparent
        # exists, i.e. a parent is assumed to have both or neither parents.
        if grandparent is not None:
            grandparent.update_child(parent, sib)
            grandparent = parent.right_parent
            grandparent.update_child(parent, sib)

    def reattach(self, u, v, t, new_names):
        '''
        Reattach node ``u`` above node ``v`` at time ``t``.

        If ``u`` has no parent, a new parent node is allocated with an index from
        ``new_name()`` (also recorded in ``new_names``); otherwise ``u``'s existing
        single parent is reused.  That parent gets time ``t``, takes over ``v``'s
        breakpoint (cleared on ``v``) and ``v``'s parents (grandparents are updated via
        ``update_child``), becomes ``v``'s only parent, and stores ``v`` in whichever
        child slot does not hold ``u``.

        :param new_names: avltree of the indices of all new nodes created for the new
            ARG in mcmc; updated in place.
        :return: ``new_names``.
        '''
        # Reattaches node u above node v at time t, new_names is a avltree of all
        #new nodes.index in a new ARG in mcmc
        assert t > v.time
        # assert v.left_parent == None or t < v.left_parent.time
        if u.left_parent is None:  # new_name
            new_name = self.new_name()
            new_names[new_name] = new_name
            # self.coal[new_name] = new_name # add the new CA parent to the ARG.coal
            parent = self.add(self.alloc_node(new_name))
            parent.left_child = u
            u.left_parent = parent
            u.right_parent = parent
        else:
            assert u.left_parent.index == u.right_parent.index
            parent = u.left_parent
        parent.time = t
        # the new parent now sits where v was, so it inherits v's breakpoint
        parent.breakpoint = v.breakpoint
        v.breakpoint = None
        parent.left_parent = v.left_parent
        grandparent = v.left_parent
        if grandparent is not None:
            grandparent.update_child(v, parent)
        parent.right_parent = v.right_parent
        grandparent = v.right_parent
        if grandparent is not None:
            grandparent.update_child(v, parent)
        v.left_parent = parent
        v.right_parent = parent
        # put v into the child slot that u does not occupy
        if parent.left_child.index == u.index:
            parent.right_child = v
        else:
            parent.left_child = v
        return new_names

    def push_mutation_down(self, node, x):
        '''
        Push the mutation at site ``x`` on ``node`` as far down as possible.

        Repeatedly calls ``push_snp_down(x)`` on the node it returns until that call
        reports it is blocked.  Mutations should normally already sit at their lowest
        possible position; this is mainly useful for an initial ARG.
        '''
        while True:
            node, blocked = node.push_snp_down(x)
            if blocked:
                break

    def push_all_mutations_down(self, node):
        '''Push every mutation currently on ``node`` as low as possible.

        The SNP keys are snapshotted first, presumably because pushing a mutation
        down can remove it from ``node.snps`` while we iterate.
        '''
        for site in list(node.snps):
            self.push_mutation_down(node, site)

    def find_tmrca(self, node, x):
        '''
        Take one step up from ``node`` along the lineage of site ``x``.

        :return: ``(next_node, block)``.  ``block`` is True when the walk should stop:
            either ``node`` has no parent (``node`` itself is returned), or its single
            parent does not contain ``x`` (that parent is returned as the MRCA for ``x``).
        '''
        parent = node.left_parent
        if parent is None:
            return node, True
        # Compare indices by value (was `is not`, an identity test on ints), consistent
        # with the other index comparisons in this class.
        if parent.index != node.right_parent.index:
            # recombination: exactly one of the two parents carries site x
            assert parent.contains(x) + node.right_parent.contains(
                x) == 1
            if parent.contains(x):
                return parent, False
            return node.right_parent, False
        if parent.contains(x):
            return parent, False
        # it is mrca for x
        return parent, True

    def tmrca(self, x):
        '''Time to the MRCA of site ``x``.

        Starts at node 0 and follows the path of ``x`` with ``find_tmrca`` until it
        blocks, then returns that node's time.
        '''
        current = self.__getitem__(0)
        while True:
            current, done = self.find_tmrca(current, x)
            if done:
                return current.time

    def total_tmrca(self, sequence_length):
        '''
        Return a numpy array of length ``int(sequence_length)`` holding the TMRCA of every site.

        Ancestral-recombination breakpoints plus 0 and ``sequence_length`` split the
        sequence into intervals; each interval is filled with the TMRCA of its left end.
        '''
        cuts = self.breakpoints(only_ancRec=True, set=True)
        cuts.add(0)
        cuts.add(sequence_length)
        per_site = np.zeros(int(sequence_length))
        for i in range(len(cuts) - 1):
            lo, hi = cuts[i], cuts[i + 1]
            per_site[int(lo):int(hi)] = self.tmrca(lo)
        return per_site

    def mean_tmrca(self, sequence_length):
        '''Summarise the ARG's TMRCA over the intervals between ancestral-recombination breakpoints.

        For every interval the TMRCA at its left end is multiplied by the interval
        length, and ``np.mean`` of those products is returned.
        NOTE(review): that divides the span-weighted sum by the number of intervals,
        not by ``sequence_length`` -- confirm this is the intended "mean".
        '''
        cuts = self.breakpoints(only_ancRec=True, set=True)
        cuts.add(0)
        cuts.add(sequence_length)
        weighted = []
        for i in range(len(cuts) - 1):
            left, right = cuts[i], cuts[i + 1]
            weighted.append(self.tmrca(left) * (int(right) - int(left)))
        return np.mean(weighted)

    def allele_age(self):
        ''':return a pd df with four columns:
            1. site: the genomic position of the SNP
            2. recent age: the time of the node carrying the mutation
            3. mid age: the midpoint of that node time and its parent (tree node) time
            4. latest age: the parent (tree node) time, the latest time back for the mutation
            Rows are sorted by site and the index is reset.
         '''
        carriers = [n for n in self.nodes.values() if n.snps]
        age_df = pd.DataFrame(
            columns=["site", "recent age", "mid age", "latest age"])
        for carrier in carriers:
            recent = carrier.time
            for site in carrier.snps:
                latest = carrier.tree_node_age(site, return_parent_time=True)
                age_df.loc[age_df.shape[0]] = [site, recent, (recent + latest) / 2, latest]
        age_df.sort_values(by=['site'], ascending=True, inplace=True)
        age_df.reset_index(inplace=True, drop=True)
        return age_df

    def invisible_recombs(self):
        '''Return the proportion of invisible recombinations.

        Counts nodes with a breakpoint for which ``is_invisible_recomb()`` holds and
        divides by ``num_ancestral_recomb + num_nonancestral_recomb``.  Returns 0.0
        when no recombinations are recorded (previously ZeroDivisionError).
        '''
        total = self.num_ancestral_recomb + self.num_nonancestral_recomb
        if total == 0:
            return 0.0
        invis_count = sum(1 for node in self.nodes.values()
                          if node.breakpoint is not None and node.is_invisible_recomb())
        return invis_count / total

    #@property

    def breakpoints(self, only_ancRec=False, set=True):
        '''
        Collect recombination breakpoints from the nodes.

        :param only_ancRec: keep only ancestral ones, i.e. breakpoints the node itself
            contains (``node.contains(node.breakpoint)``)
        :param set: if truthy return a SortedSet (unique positions), otherwise a
            SortedList (repetitions kept)
        :return: SortedSet or SortedList of breakpoint positions
        '''
        collected = SortedSet() if set else SortedList()
        for node in self.nodes.values():
            point = node.breakpoint
            if point is None:
                continue
            if only_ancRec and not node.contains(point):
                continue
            collected.add(point)
        return collected

    #========== probabilities
    def log_likelihood(self, mutation_rate, data):
        '''
        Log-likelihood of the mutations on this ARG, up to a normalising
        constant. That constant depends on the pattern of observed mutations
        but not on the ARG or the mutation rate.

        After an SPR move and before clean-up the ARG may contain lineages
        with no ancestral material (NAM). Nodes whose ``first_segment`` is
        None are skipped, so they contribute neither material nor mutations.

        Side effect: ``self.branch_length`` is set to the total ancestral
        material, i.e. the sum over segments of (right - left) * branch age.

        :param mutation_rate: mutation rate per unit of material
        :param data: observed SNP data; ``len(data)`` must equal the number
            of mutations placed on the nodes (asserted)
        :return: the log-likelihood; 0 when mutation_rate is 0 and there are
            no mutations, -inf when mutation_rate is 0 and there are some
        '''
        mutated = []
        material = 0
        n_mut = 0
        # Pass 1: total material and the nodes carrying mutations.
        for nd in self.nodes.values():
            seg = nd.first_segment
            if seg is None:
                continue
            assert nd.left_parent != None
            branch_age = nd.left_parent.time - nd.time
            assert seg.prev == None
            while seg is not None:
                material += ((seg.right - seg.left) * branch_age)
                seg = seg.next
            if nd.snps:
                n_mut += len(nd.snps)
                mutated.append(nd)
        self.branch_length = material
        assert n_mut == len(data)  # num of snps
        # Poisson term for the total number of mutations.
        if mutation_rate != 0:
            scaled = material * mutation_rate
            ret = n_mut * math.log(scaled) - scaled
        else:
            ret = 0 if n_mut == 0 else -float("inf")
        # Pass 2: probability of each mutation sitting on its branch.
        for nd in mutated:
            for site in nd.snps:
                ret += math.log(nd.tree_node_age(site) / material)
            # check the mutation is placed on the correct node
            verify_mutation_node(nd, data)
        return ret

    def log_prior(self,
                  sample_size,
                  sequence_length,
                  recombination_rate,
                  Ne,
                  NAM=True,
                  new_roots=False,
                  kuhner=False):
        '''
        Log prior of the ARG under the coalescent with recombination.

        Nodes are visited in increasing time order, starting at position
        ``sample_size`` of that ordering. This presumably skips the sampled
        leaves at time 0; confirm with callers. Between events the total
        event rate is
            k * (k - 1) / (4 * Ne) + L * recombination_rate
        where k is the current number of lineages and L the current number
        of links.

        Each recombination contributes ``-rate * dt + log(recombination_rate)``.
        Each coalescence whose two children both carry material contributes
        ``-rate * dt - log(2 * Ne)``.

        The method may be called after an SPR move and before clean-up, when
        some lineages may carry no ancestral material (NAM). Any other node
        with a child lacking material is skipped: it adds nothing and does
        not advance the clock.

        Side effects:
        - ``self.num_ancestral_recomb`` and ``self.num_nonancestral_recomb``
          are reset and recounted.
        - When ``kuhner`` is True, ``self.rec`` is cleared and refilled with
          the recombination nodes visited.

        :param sample_size: number of sampled lineages
        :param sequence_length: sequence length; each sample starts with
            ``sequence_length - 1`` links
        :param recombination_rate: recombination rate per link
        :param Ne: effective population size
        :param NAM: whether no-ancestral-material nodes are allowed. After an
            SPR and before clean-up some NAM nodes are expected. After
            clean-up, or on the initial ARG, there should be none; passing
            False asserts that both children of every visited node carry
            material.
        :param new_roots: if True, also collect the indices of root nodes
            (coalescences with no material) and of all counted coalescence
            nodes
        :param kuhner: if True, rebuild ``self.rec`` during the traversal
        :return: the log prior, or ``(log_prior, roots, new_coal)`` when
            ``new_roots`` is True; ``roots`` and ``new_coal`` are
            ``bintrees.AVLTree`` objects mapping node index -> node index
        '''
        # order nodes by time
        #TODO: find an efficient way to order nodes
        ordered_nodes = [
            v for k, v in sorted(self.nodes.items(),
                                 key=lambda item: item[1].time)
        ]
        number_of_lineages = sample_size
        number_of_links = number_of_lineages * (sequence_length - 1)
        number_of_nodes = self.__len__()
        counter = sample_size
        time = 0
        ret = 0
        # rec_count / coal_count are tallied below but never used or returned
        rec_count = 0
        coal_count = 0
        roots = bintrees.AVLTree()
        new_coal = bintrees.AVLTree()
        if kuhner:
            self.rec.clear()
        self.num_ancestral_recomb = 0
        self.num_nonancestral_recomb = 0
        while counter < number_of_nodes:
            node = ordered_nodes[counter]
            assert node.time >= time  # make sure it is ordered
            # total rate: coalescence among all pairs + recombination on all links
            rate = (number_of_lineages * (number_of_lineages - 1) /
                    (4 * Ne)) + (number_of_links * (recombination_rate))
            # ret -= rate * (node.time - time)
            # A node whose two children are the same node is a recombination
            # parent: the recombining lineage split into two parents.
            if node.left_child.index == node.right_child.index:  #rec
                assert node.left_child.first_segment != None
                assert node.left_child.left_parent.first_segment != None
                assert node.left_child.right_parent.first_segment != None
                ret -= rate * (node.time - time)
                # gap = links lost by the split: the child's links minus the
                # links kept by its two parents
                gap = node.left_child.num_links()-\
                      (node.left_child.left_parent.num_links() +
                       node.left_child.right_parent.num_links())
                ret += math.log(recombination_rate)
                assert gap >= 1
                # a gap of exactly one link is counted as an ancestral
                # recombination, larger gaps as non-ancestral
                if gap == 1:
                    self.num_ancestral_recomb += 1
                else:
                    self.num_nonancestral_recomb += 1
                number_of_links -= gap
                number_of_lineages += 1
                if kuhner:  # add rec
                    self.rec[node.index] = node.index
                    self.rec[ordered_nodes[counter +
                                           1].index] = ordered_nodes[counter +
                                                                     1].index
                # Skip the next node too. Assumes the other recombination
                # parent (same time) is adjacent in the time ordering; confirm.
                counter += 2
                time = node.time
                rec_count += 1
            elif node.left_child.first_segment != None and\
                        node.right_child.first_segment != None:
                # coalescence of two lineages that both carry material
                ret -= rate * (node.time - time)
                ret -= math.log(2 * Ne)
                if node.first_segment == None:
                    # no material above this node: it is a root, and both
                    # merged lineages leave the process
                    node_numlink = 0
                    number_of_lineages -= 2
                    counter += 1
                    if new_roots:
                        roots[node.index] = node.index
                else:
                    node_numlink = node.num_links()
                    number_of_lineages -= 1
                    counter += 1
                lchild_numlink = node.left_child.num_links()
                rchild_numlink = node.right_child.num_links()
                number_of_links -= (lchild_numlink +
                                    rchild_numlink) - node_numlink
                time = node.time
                coal_count += 1
                if new_roots:
                    new_coal[node.index] = node.index
            else:
                # a child lacks material (NAM): ignored, clock not advanced
                counter += 1
            if not NAM:
                assert node.left_child.first_segment != None
                assert node.right_child.first_segment != None
        if new_roots:
            return ret, roots, new_coal
        else:
            return ret

    def dump(self, path=' ', file_name='arg.arg'):
        '''Pickle this ARG to the file ``path + "/" + file_name``.

        The two parts are joined with a literal "/" and no normalisation, so
        the default ``path=' '`` writes to " /arg.arg".
        '''
        destination = "/".join([path, file_name])
        with open(destination, "wb") as handle:
            pickle.dump(self, handle)

    def load(self, path=' '):
        '''Unpickle and return the object stored in the file ``path``.

        ``self`` is neither read nor modified; the loaded object is returned.
        SECURITY: unpickling can execute arbitrary code, so only load
        files from a trusted source.
        '''
        with open(path, "rb") as stream:
            restored = pickle.load(stream)
        return restored

    def verify(self):
        '''
        Check structural invariants of the ARG.

        Raises AssertionError on the first violation. For some checks,
        diagnostic output is printed just before the failing assertion.

        Checked for every node in ``self.nodes``:
        1. it is stored under its own index
        2. a node with no parent (a root) has no segment, is in self.roots
           and self.coal, has no breakpoint and no right parent, has two
           distinct children, and is older than both children
        3. a node with a parent has a well-formed segment list (head has no
           prev, tail has no next), is not in self.roots, and is younger
           than its left parent
        4. a node with no left child has no right child and time 0 (a leaf)
        5. a node with two different parents (recombination) has a
           breakpoint. Both parents are in self.rec and share the same time,
           neither carries snps, each parent's two children coincide, the
           right parent's child is this node, and the left parent's first
           segment starts strictly left of the right parent's.
        6. a node whose parents coincide hangs below a node in self.coal
           with two distinct children, and has no breakpoint
        7. every segment has non-empty samples, left < right, and points
           back to its node
        '''
        for node in self.nodes.values():
            assert self.nodes[node.index].index == node.index
            if node.left_parent is None:  #roots
                if node.first_segment is not None:
                    print("in verrify node is ", node.index)
                    self.print_state()
                assert node.first_segment is None
                assert node.index in self.roots
                assert node.breakpoint is None
                assert node.left_child.index != node.right_child.index
                assert node.right_parent is None
                assert node.index in self.coal
                assert node.time > node.left_child.time
                assert node.time > node.right_child.time
            else:  # rest
                assert node.first_segment is not None
                assert node.first_segment.prev is None
                assert node.get_tail().next is None
                assert node.index not in self.roots
                assert node.left_parent.time > node.time
                if node.left_child is None:  #leaves
                    assert node.right_child is None
                    assert node.time == 0
                if node.left_parent.index != node.right_parent.index:
                    # recombination: two distinct parents
                    assert node.breakpoint is not None
                    assert node.left_parent.left_child.index ==\
                           node.left_parent.right_child.index
                    assert node.right_parent.left_child.index ==\
                        node.right_parent.right_child.index
                    assert node.right_parent.left_child.index == node.index
                    assert not node.left_parent.snps
                    assert not node.right_parent.snps
                    assert node.left_parent.time == node.right_parent.time
                    assert node.left_parent.index in self.rec
                    assert node.right_parent.index in self.rec
                    lp_left = node.left_parent.first_segment.left
                    rp_left = node.right_parent.first_segment.left
                    # Print diagnostics exactly when the assertion below is
                    # about to fail (previously ties failed silently).
                    if not lp_left < rp_left:
                        print("in verify node", node.index)
                        print("node.left_parent", node.left_parent.index)
                        print("node.right_parent", node.right_parent.index)
                    assert lp_left < rp_left
                else:
                    # coalescence: both parent pointers name the same node
                    assert node.left_parent.index in self.coal
                    assert node.left_parent.left_child.index !=\
                           node.left_parent.right_child.index
                    assert node.breakpoint is None
            if node.first_segment is not None:
                seg = node.first_segment
                assert seg.prev is None
                while seg is not None:
                    assert seg.samples
                    assert seg.left < seg.right
                    assert seg.node.index == node.index
                    seg = seg.next

    def print_state(self):
        '''Print a tab-separated dump of the ARG for debugging.

        First prints ``self.coal``, ``self.rec`` and ``self.roots``, then a
        header row. Then, for each node that has a parent or a child (nodes
        with neither are skipped), it prints:
        - a single row with "root" in the left/right columns if the node
          has no segments, or
        - one row per segment otherwise. Child columns show "Leaf" for
          nodes without children; parent columns show "Root" for nodes
          without a parent.
        '''
        print("self.arg.coal", self.coal)
        print("self.arg.rec", self.rec)
        print("self.arg.roots", self.roots)
        print("node",
              "time",
              "left",
              "right",
              "l_chi",
              "r_chi",
              "l_par",
              "r_par",
              "l_bp",
              "snps",
              "fir_seg_sam",
              sep="\t")
        for j in self.nodes:
            node = self.__getitem__(j)
            if node.left_parent is not None or node.left_child is not None:
                s = node.first_segment
                if s is None:
                    # node without segments (e.g. a root).
                    # NOTE(review): prints the parent objects themselves rather
                    # than their .index; presumably both are None here.
                    print(j,
                          "%.5f" % node.time,
                          "root",
                          "root",
                          node.left_child.index,
                          node.right_child.index,
                          node.left_parent,
                          node.right_parent,
                          node.breakpoint,
                          node.snps,
                          None,
                          sep="\t")

                # one row per segment of the node's material
                while s is not None:
                    l = s.left
                    r = s.right
                    if node.left_child is None:
                        print(j,
                              "%.5f" % node.time,
                              l,
                              r,
                              "Leaf",
                              "Leaf",
                              node.left_parent.index,
                              node.right_parent.index,
                              node.breakpoint,
                              node.snps,
                              s.samples,
                              sep="\t")  #
                    elif node.left_parent is None:
                        print(j,
                              "%.5f" % node.time,
                              l,
                              r,
                              node.left_child.index,
                              node.right_child.index,
                              "Root",
                              "Root",
                              node.breakpoint,
                              node.snps,
                              s.samples,
                              sep="\t")
                    else:
                        print(j,
                              "%.5f" % node.time,
                              l,
                              r,
                              node.left_child.index,
                              node.right_child.index,
                              node.left_parent.index,
                              node.right_parent.index,
                              node.breakpoint,
                              node.snps,
                              s.samples,
                              sep="\t")
                    s = s.next