Example #1
class PrioritizedIntensity(object):
    _MIN_VALUE = 0.005

    def __init__(self):
        self._values = SortedDict()

    def set(self, value, priority=100):
        value = float(value)
        if value < self._MIN_VALUE and priority in self._values:
            del self._values[priority]
        else:
            self._values[priority] = value

    def eval(self):
        if not self._values:
            return 0.0
        return self._values[self._values.iloc[-1]]

    def top_priority(self):
        if not self._values:
            return 0
        return self._values.keys()[len(self._values) - 1]

    def reset(self):
        self._values.clear()
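
A minimal standalone sketch (not part of the original class, hypothetical priorities and values) of the SortedDict behaviour that eval() and top_priority() rely on: the entry with the largest priority key is always at the end of the sorted order.

from sortedcontainers import SortedDict

values = SortedDict()
values[100] = 0.4   # default priority
values[200] = 0.9   # higher priority wins

# peekitem(-1) returns the (key, value) pair with the largest key,
# the same pair eval()/top_priority() reach via iloc[-1]/keys()[-1].
priority, intensity = values.peekitem(-1)
print(priority, intensity)  # 200 0.9
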
Example #2
def createTermIndex():
    sortTepDic = SortedDict()
    # Structure for each term
    #   sortTepDic['term']=({'DocId1':['Pos1','Pos2'],'DocId2':['Pos1','Pos2']},'termFreq','DocFreq')

    for root, dirs, files in os.walk(Contants.DATA_DIRECTORY_NAME, topdown=True):
        for name in files:
            file_name = os.path.join(root, name)
            #         'r' when the file will only be read
            #         'w' for only writing (an existing file with the same name will be erased)
            #         'a' opens the file for appending; any data written to the file is automatically added to the end.
            #         'r+' opens the file for both reading and writing.

            mode = "r"
            file_object = open(file_name, mode)
            DocId = os.path.split(file_name)[1]

            wordPos = 0
            for word in file_object.read().split():

                wordPos = wordPos + 1  # increment word location
                lamma = applyFilters(word)

                if lamma:
                    if lamma not in sortTepDic:
                        sortTepDic[lamma] = [{DocId: [wordPos]}, 1, 1]  # add a new term

                    else:

                        sortTepDic[lamma][1] = sortTepDic[lamma][1] + 1  # increment the term frequency

                        if DocId in sortTepDic[lamma][0]:
                            sortTepDic[lamma][0][DocId].append(
                                wordPos
                            )  # add new word position for the existing document
                        else:
                            sortTepDic[lamma][0][DocId] = [wordPos]  # add a new document ID and the word position
                            sortTepDic[lamma][2] = sortTepDic[lamma][2] + 1  # increment the document frequency

    # convert lists to tuples
    for key in sortTepDic.keys():
        for DocId in sortTepDic[key][0]:
            sortTepDic[key][0][DocId] = tuple(sortTepDic[key][0][DocId])
        sortTepDic[key] = tuple(sortTepDic[key])

    Data.write_dataStruct_to_file(Contants.WORD_INDEX_FILE_NAME, sortTepDic)
    createLexicon(sortTepDic)
    createPostingList(sortTepDic)
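
A minimal standalone sketch of the posting-list structure built above, with hypothetical terms and document IDs; it shows the same term-frequency and document-frequency bookkeeping without the file walking.

from sortedcontainers import SortedDict

# term -> [ {doc_id: [positions]}, term_frequency, document_frequency ]
index = SortedDict()

def add_occurrence(term, doc_id, pos):
    if term not in index:
        index[term] = [{doc_id: [pos]}, 1, 1]
        return
    entry = index[term]
    entry[1] += 1                      # term frequency
    if doc_id in entry[0]:
        entry[0][doc_id].append(pos)   # another position in a known document
    else:
        entry[0][doc_id] = [pos]       # first occurrence in a new document
        entry[2] += 1                  # document frequency

add_occurrence('apple', 'doc1', 3)
add_occurrence('apple', 'doc1', 9)
add_occurrence('apple', 'doc2', 1)
print(index['apple'])  # [{'doc1': [3, 9], 'doc2': [1]}, 3, 2]
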
Example #3
class InMemoryStorage(object):
    def __init__(self):
        self.kvstore = SortedDict()  # hashtable

    def get(self, k):
        try:
            return self.kvstore[k]
        except KeyError:
            return 1

    def put(self, k, v):
        self.kvstore[k] = v
        return 0

    def delete(self, k):
        try:
            del self.kvstore[k]
            return 0
        except KeyError:
            return 1

    def split(self, section, keyspace_mid):
        """ delete one half of keystore for group split operation """
        midKey = None
        for key in self.kvstore.keys():  # TODO make more efficient for better performance
            if key > str(keyspace_mid):  # use iloc to estimate midpoint
                midKey = self.kvstore.index(key)
                break

        if section:  # section is either 0 or 1
            self.kvstore = SortedDict(self.kvstore.items()[midKey:])

        else:
            self.kvstore = SortedDict(self.kvstore.items()[:midKey])
        print(self.kvstore)
        return 0

    def save(self):  # need metadata here
        save_state("data/backup/db_copy.pkl", self.kvstore)

    def load(self):
        self.kvstore = load_state("data/backup/db_copy.pkl")
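
For comparison, a minimal sketch (hypothetical keys) of the keyspace split done with bisect_left instead of a linear scan, keeping both halves as SortedDicts, which is what later get/put calls on kvstore expect.

from sortedcontainers import SortedDict

store = SortedDict({'a': 1, 'b': 2, 'm': 3, 'x': 4})
mid = store.bisect_left('m')              # index of the first key >= 'm'

lower = SortedDict(store.items()[:mid])   # keys 'a', 'b'
upper = SortedDict(store.items()[mid:])   # keys 'm', 'x'
print(list(lower), list(upper))           # ['a', 'b'] ['m', 'x']
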
Example #4
def test_keysview():
    mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)]
    temp = SortedDict(mapping[:13])
    keys = temp.keys()

    assert len(keys) == 13
    assert 'a' in keys
    assert list(keys) == [val for val, pos in mapping[:13]]
    assert keys[0] == 'a'
    assert list(reversed(keys)) == list(reversed(string.ascii_lowercase[:13]))
    assert keys.index('f') == 5
    assert keys.count('m') == 1
    assert keys.count('0') == 0
    assert keys.isdisjoint(['1', '2', '3'])

    temp.update(mapping[13:])

    assert len(keys) == 26
    assert 'z' in keys
    assert list(keys) == [val for val, pos in mapping]

    that = dict(mapping)

    that_keys = get_keysview(that)

    assert keys == that_keys
    assert not (keys != that_keys)
    assert not (keys < that_keys)
    assert not (keys > that_keys)
    assert keys <= that_keys
    assert keys >= that_keys

    assert list(keys & that_keys) == [val for val, pos in mapping]
    assert list(keys | that_keys) == [val for val, pos in mapping]
    assert list(keys - that_keys) == []
    assert list(keys ^ that_keys) == []

    keys = SortedDict(mapping[:2]).keys()
    assert repr(keys) == "SortedKeysView(SortedDict({'a': 0, 'b': 1}))"
Example #5
def plotWidth(dwdictX,fname,nameX,mX,cuts):
   sorted_dwdictX = SortedDict(dwdictX)
   n = len(sorted_dwdictX)-1
   x = array('d',sorted_dwdictX.keys())
   y = array('d',sorted_dwdictX.values())
   gwX = TGraph(n,x,y)
   gwX.SetName("gwX")
   gwX.SetTitle("")
   gwX.GetXaxis().SetTitle("tan#beta")
   gwX.GetYaxis().SetTitle("#Gamma_{#it{"+nameX+"}}/#it{m}_{#it{"+nameX+"}} [%]")
   gwX.SetLineColor(ROOT.kBlack)
   gwX.SetMarkerColor(ROOT.kBlack)
   gwX.SetMarkerStyle(20)
   gwX.SetMarkerSize(0.5)

   ptxt = TPaveText(0.62,0.70,0.87,0.87,"NDC")
   ptxt.SetFillStyle(4000) #will be transparent
   ptxt.SetFillColor(0)
   ptxt.SetTextFont(42)
   ptxt.SetBorderSize(0)
   ptxt.AddText("sin(#beta-#alpha)=1")
   ptxt.AddText("#it{m}_{#it{"+nameX+"}}="+str(mX)+" GeV")

   c = TCanvas("c","c",600,600)
   c.cd()
   c.SetLogx()
   c.SetLogy()
   c.SetGridx()
   c.SetGridy()
   c.SetTicks(1,1)
   c.Draw()
   # gwX.Draw("p")
   gwX.Draw()
   ptxt.Draw("same")
   c.Modified()
   c.Update()
   c.SaveAs(fname)
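
The ROOT calls aside, the part this example leans on SortedDict for is the data preparation: keys() and values() come out in x order and can be fed straight into array('d', ...). A minimal sketch with a hypothetical width table:

from array import array
from sortedcontainers import SortedDict

dwdictX = {5.0: 0.12, 1.0: 0.30, 20.0: 0.05}   # hypothetical tan(beta) -> width
sorted_dw = SortedDict(dwdictX)

x = array('d', sorted_dw.keys())
y = array('d', sorted_dw.values())
print(list(x), list(y))  # x comes out sorted: [1.0, 5.0, 20.0]
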
Example #6
class FederationRemoteSendQueue(object):
    """A drop in replacement for TransactionQueue"""
    def __init__(self, hs):
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        self.presence_map = {
        }  # Pending presence map user_id -> UserPresenceState
        self.presence_changed = SortedDict()  # Stream position -> user_id

        self.keyed_edu = {}  # (destination, key) -> EDU
        self.keyed_edu_changed = SortedDict(
        )  # stream position -> (destination, key)

        self.edus = SortedDict()  # stream position -> Edu

        self.failures = SortedDict(
        )  # stream position -> (destination, Failure)

        self.device_messages = SortedDict()  # stream position -> destination

        self.pos = 1
        self.pos_time = SortedDict()

        # EVERYTHING IS SAD. In particular, python only makes new scopes when
        # we make a new function, so we need to make a new function so the inner
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name, queue):
            LaterGauge(
                "synapse_federation_send_queue_%s_size" % (queue_name, ), "",
                [], lambda: len(queue))

        for queue_name in [
                "presence_map",
                "presence_changed",
                "keyed_edu",
                "keyed_edu_changed",
                "edus",
                "failures",
                "device_messages",
                "pos_time",
        ]:
            register(queue_name, getattr(self, queue_name))

        self.clock.looping_call(self._clear_queue, 30 * 1000)

    def _next_pos(self):
        pos = self.pos
        self.pos += 1
        self.pos_time[self.clock.time_msec()] = pos
        return pos

    def _clear_queue(self):
        """Clear the queues for anything older than N minutes"""

        FIVE_MINUTES_AGO = 5 * 60 * 1000
        now = self.clock.time_msec()

        keys = self.pos_time.keys()
        time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
        if not keys[:time]:
            return

        position_to_delete = max(keys[:time])
        for key in keys[:time]:
            del self.pos_time[key]

        self._clear_queue_before_pos(position_to_delete)

    def _clear_queue_before_pos(self, position_to_delete):
        """Clear all the queues from before a given position"""
        with Measure(self.clock, "send_queue._clear"):
            # Delete things out of presence maps
            keys = self.presence_changed.keys()
            i = self.presence_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_changed[key]

            user_ids = set(user_id
                           for uids in itervalues(self.presence_changed)
                           for user_id in uids)

            to_del = [
                user_id for user_id in self.presence_map
                if user_id not in user_ids
            ]
            for user_id in to_del:
                del self.presence_map[user_id]

            # Delete things out of keyed edus
            keys = self.keyed_edu_changed.keys()
            i = self.keyed_edu_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.keyed_edu_changed[key]

            live_keys = set()
            for edu_key in self.keyed_edu_changed.values():
                live_keys.add(edu_key)

            to_del = [
                edu_key for edu_key in self.keyed_edu
                if edu_key not in live_keys
            ]
            for edu_key in to_del:
                del self.keyed_edu[edu_key]

            # Delete things out of edu map
            keys = self.edus.keys()
            i = self.edus.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.edus[key]

            # Delete things out of failure map
            keys = self.failures.keys()
            i = self.failures.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.failures[key]

            # Delete things out of device map
            keys = self.device_messages.keys()
            i = self.device_messages.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.device_messages[key]

    def notify_new_events(self, current_id):
        """As per TransactionQueue"""
        # We don't need to replicate this as it gets sent down a different
        # stream.
        pass

    def send_edu(self, destination, edu_type, content, key=None):
        """As per TransactionQueue"""
        pos = self._next_pos()

        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        if key:
            assert isinstance(key, tuple)
            self.keyed_edu[(destination, key)] = edu
            self.keyed_edu_changed[pos] = (destination, key)
        else:
            self.edus[pos] = edu

        self.notifier.on_new_replication_data()

    def send_presence(self, states):
        """As per TransactionQueue

        Args:
            states (list(UserPresenceState))
        """
        pos = self._next_pos()

        # We only want to send presence for our own users, so lets always just
        # filter here just in case.
        local_states = list(
            filter(lambda s: self.is_mine_id(s.user_id), states))

        self.presence_map.update(
            {state.user_id: state
             for state in local_states})
        self.presence_changed[pos] = [state.user_id for state in local_states]

        self.notifier.on_new_replication_data()

    def send_failure(self, failure, destination):
        """As per TransactionQueue"""
        pos = self._next_pos()

        self.failures[pos] = (destination, str(failure))
        self.notifier.on_new_replication_data()

    def send_device_messages(self, destination):
        """As per TransactionQueue"""
        pos = self._next_pos()
        self.device_messages[pos] = destination
        self.notifier.on_new_replication_data()

    def get_current_token(self):
        return self.pos - 1

    def federation_ack(self, token):
        self._clear_queue_before_pos(token)

    def get_replication_rows(self,
                             from_token,
                             to_token,
                             limit,
                             federation_ack=None):
        """Get rows to be sent over federation between the two tokens

        Args:
            from_token (int)
            to_token(int)
            limit (int)
            federation_ack (int): Optional. The position where the worker is
                explicitly acknowledged it has handled. Allows us to drop
                data from before that point
        """
        # TODO: Handle limit.

        # To handle restarts where we wrap around
        if from_token > self.pos:
            from_token = -1

        # list of tuple(int, BaseFederationRow), where the first is the position
        # of the federation stream.
        rows = []

        # There should be only one reader, so lets delete everything its
        # acknowledged its seen.
        if federation_ack:
            self._clear_queue_before_pos(federation_ack)

        # Fetch changed presence
        i = self.presence_changed.bisect_right(from_token)
        j = self.presence_changed.bisect_right(to_token) + 1
        dest_user_ids = [
            (pos, user_id)
            for pos, user_id_list in self.presence_changed.items()[i:j]
            for user_id in user_id_list
        ]

        for (key, user_id) in dest_user_ids:
            rows.append((key, PresenceRow(state=self.presence_map[user_id], )))

        # Fetch changes keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)
        j = self.keyed_edu_changed.bisect_right(to_token) + 1
        # We purposefully clobber based on the key here, python dict comprehensions
        # always use the last value, so this will correctly point to the last
        # stream position.
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in iteritems(keyed_edus):
            rows.append((pos,
                         KeyedEduRow(
                             key=edu_key,
                             edu=self.keyed_edu[(destination, edu_key)],
                         )))

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)
        j = self.edus.bisect_right(to_token) + 1
        edus = self.edus.items()[i:j]

        for (pos, edu) in edus:
            rows.append((pos, EduRow(edu)))

        # Fetch changed failures
        i = self.failures.bisect_right(from_token)
        j = self.failures.bisect_right(to_token) + 1
        failures = self.failures.items()[i:j]

        for (pos, (destination, failure)) in failures:
            rows.append(
                (pos, FailureRow(
                    destination=destination,
                    failure=failure,
                )))

        # Fetch changed device messages
        i = self.device_messages.bisect_right(from_token)
        j = self.device_messages.bisect_right(to_token) + 1
        device_messages = {v: k for k, v in self.device_messages.items()[i:j]}

        for (destination, pos) in iteritems(device_messages):
            rows.append((pos, DeviceRow(destination=destination, )))

        # Sort rows based on pos
        rows.sort()

        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
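
A minimal sketch of the pruning step used by _clear_queue_before_pos: bisect_left gives the number of stream positions strictly below the cutoff, and the keys() view slice yields exactly those keys. The queue contents here are hypothetical.

from sortedcontainers import SortedDict

edus = SortedDict({1: 'edu-a', 3: 'edu-b', 7: 'edu-c', 9: 'edu-d'})
position_to_delete = 7

i = edus.bisect_left(position_to_delete)
for key in edus.keys()[:i]:   # the slice is a plain list, safe to delete over
    del edus[key]

print(dict(edus))  # {7: 'edu-c', 9: 'edu-d'}
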
Example #7
class ComparisonHelper(object):
    """ Helps comparisons vs. the observed open stoma """

    def __init__(self, reference_dimensions, optimisation_def, open_pressure ):
        """
        :param reference_dimensions: The observed dimensions
        :type reference_dimensions: dict
        :param optimisation_def: The keys and any aliases for the optimisation. The aliases provide mapping(s) for any
        names that differ
        :type optimisation_def: dict
        :param open_pressure: The guard cell turgor pressure when the stoma is fully open
        :type open_pressure: float
        """

        self.reference_dimensions = SortedDict( reference_dimensions )
        self.optimisation_keys = sorted( optimisation_def[ 'keys' ] )
        self.key_aliases = optimisation_def[ 'aliases' ]
        self.open_pressure = open_pressure

        # make sure the keys on which the optimisation will be performed are in the observed set
        if len( set( self.optimisation_keys ) - set( reference_dimensions.keys() ) ) > 0:
            raise KeyError( "One or more of the optimisation keys {} is/are missing from the reference dimensions {}".
                            format( set( self.optimisation_keys ), set( self.reference_dimensions.keys() ) ) )

        # make sure any key aliases for the optimisation are in the observed set
        if len( set( self.key_aliases.keys() ) - set( reference_dimensions.keys() ) ) > 0:
            raise KeyError( "One or more of the key aliases {} is/are missing from the reference dimensions {}".
                            format( set( self.key_aliases.keys() ), set( self.reference_dimensions.keys() ) ) )

    @property
    def simulation_keys(self):
        """
        Get the keys produced by the simulation - names are set in the configuration file
        :return: list
        """
        return [ self.key_aliases[ key ] if key in self.key_aliases else key for key in self.optimisation_keys ]

    def perform_comparison(self, state_pressure, state_data):
        """
        Calculate the metric and percent difference to each measurement

        :param state_pressure: pressure at the simulation state
        :type state_pressure: float
        :param state_data: the data extracted for the state
        :type state_data: dict

        :return: Each tuple comprises 2 items: a label and then its numerical value
        :rtype: list of tuple
        """

        ref_dimensions = [ self.reference_dimensions[ key ] for key in self.optimisation_keys ]
        sim_dimensions = [ state_data[ key ] for key in self.simulation_keys ]

        # the difference will be -1 if the simulation value is zero - happens to pore-width when the stoma closes
        pc_diffs = [ sim / ref - 1.0 for ref, sim in zip( ref_dimensions, sim_dimensions ) ]

        # metric is sqrt( sum[ ln( predict_i / actual_i )**2 ] )
        metric_raw = sum( log( max( abs( 1 + pc_diff ), 1e-5 ) ) ** 2 for pc_diff in pc_diffs )
        metric_raw = sqrt( metric_raw )

        dp = self.open_pressure - state_pressure

        # weight the metric to the end time/pressure
        metric_weighted = metric_raw + dp

        result_keys = [ 'metric', 'metric_raw', 'dp' ] + [ 'pc-{}'.format( k ) for k in self.simulation_keys ]
        result_vals = [ metric_weighted, metric_raw, dp ] + pc_diffs

        return zip( result_keys, result_vals )
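
A standalone sketch of the raw metric computed in perform_comparison, with hypothetical reference and simulated dimensions; the weighting by the remaining pressure difference is left out.

from math import log, sqrt

def raw_metric(ref_dims, sim_dims):
    pc_diffs = [sim / ref - 1.0 for ref, sim in zip(ref_dims, sim_dims)]
    return sqrt(sum(log(max(abs(1.0 + pc), 1e-5)) ** 2 for pc in pc_diffs))

print(raw_metric([2.0, 4.0], [2.2, 3.6]))  # close shapes give a small value
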
Example #8
print("Keys are: ", end = "  ")
for k in d1.keys():
    print(k, end="  ")
print()

print("key-value pairs are ", end="  ")
for k, v in d1.items():
    print(k,'-', v, end="    ")
print()

print("clear all in d1.clear() ", d1.clear());      print("to delete the entire dict, use del d1");     del d1;
print()

# sortedcontainers
from sortedcontainers import SortedDict
sd = SortedDict(zip('dglkjuc', range(6)));      print("Created a SortedDict and all keys are in sorted order: \n", sd, list(sd.keys()))
del sd['g'];    sd['qq'] = 11;      print("Deletion and Addition, the keys are still sorted:\n", sd, list(sd.keys()))
print()

####################################################################################################################################
# Sets are collections of unique elements and you cannot order them. But you can use set operations on them
s1 = set([1, 2, 2, 4, 5, 6, 6]);     print("Creating a set using set(). Note there is no repeat element", s1);
sc = { x**2 for x in range(8) if x > 3};     print("Using set comprehension to create a set: ", sc)
s2 = set(aaa);      print("type cast a list to create a set s2=set(aaa) is ", s2);    s2.add("added");  print("add() an item: ", s2)
s2.remove('added'); print("now remove() the member item (not index) just added: ", s2); print("the len(s2) of the set is now ", len(s2))
print("check membership 'test' in s2 ", 'test' in s2);      print("check membership not in the list: 'test' not in s2 ", 'test' not in s2)
print("s2.pop() random item from s2 set ", s2.pop());       print("to s2.clear() to delete all items in the set s2 ", s2.clear())
s3 = sorted(s1, reverse=True);        print("you can NOT use s1.sort() for sorting, to sort, use sorted(s1, reverse=True) a set: ", s3);

# set math set operations
s2 = set(aaa);  print("s1 and s2 are ", s1, s2);    
Example #9
def test_keys():
    mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)]
    temp = SortedDict(mapping)
    assert list(temp.keys()) == [key for key, pos in mapping]
Example #10
class Topics:
    """
A class that manages a collection of `Topic`s.

    """
    def __init__(self):
        self.logger = getLogger('topics')
        self.logger.info('started session')
        self.clear()

    def clear(self):
        self.logger.info('Cleared all topics and received data')
        self.topic_list = SortedDict()
        self.transfers = dict()

    def create(self, topic, source='remote'):
        # Create the topic if it doesn't exist already
        if not topic in self.topic_list:
            self.topic_list[topic] = Topic(topic,source=source)
            self.logger.info('new:topic ' + topic)

    def process(self, topic, payload, options=None):
        # Create the topic if it doesn't exist already
        self.create(topic)

        # Add the new sample
        self.topic_list[topic].new_sample(payload,options)

        # logging
        if options:
            self.logger.debug('new sample | {0} [{1}] {2}'.format(topic, options['index'], payload))
        else:
            self.logger.debug('new sample | {0} {1}'.format(topic, payload))

        # If there is an active transfer, transfer received data to the queue
        if topic in self.transfers:
            # If transfer requires indexed data, check there is an index
            if self.transfers[topic]['type'] == 'indexed' and options is not None:
                x = options['index']
                self.transfers[topic]['queue'].put([x, payload])
            # For linear data, provide sample id for x and payload for y
            elif self.transfers[topic]['type'] == 'linear':
                x = self.transfers[topic]['lastindex']
                self.transfers[topic]['queue'].put([x, payload])
                self.transfers[topic]['lastindex'] += 1

    def ls(self,source='remote'):
        if source is None:
            return sorted([t.name for t in self.topic_list.values()])
        else:
            return sorted([t.name for t in self.topic_list.values() if t.source == source])

    def samples(self,topic,amount=1):
        if not topic in self.topic_list:
            return None

        if amount == 0 or amount is None:
            return self.topic_list[topic].raw

        return self.topic_list[topic].raw[-amount:]

    def count(self,topic):
        if not topic in self.topic_list:
            return 0

        return len(self.topic_list[topic].raw)

    def exists(self,topic):
        return topic in self.topic_list

    def transfer(self, topic, queue, transfer_type = "linear"):
        # If the topic data is not already transfered to some queue
        if not topic in self.transfers:
            self.transfers[topic] = dict()
            self.transfers[topic]['queue'] = queue
            self.transfers[topic]['lastindex'] = 0
            self.transfers[topic]['type'] = transfer_type

            self.logger.info('start transfer | {0}'.format(topic))

            # If there is already existing data under the topic
            if topic in self.topic_list:
                if transfer_type == 'indexed':
                    for key, value in self.topic_list[topic].indexes.items():
                        queue.put([key, value])
                elif transfer_type == 'linear':
                    for item in self.topic_list[topic].raw:
                        queue.put([self.transfers[topic]['lastindex'], item])
                        self.transfers[topic]['lastindex'] += 1

    def untransfer(self,topic):
        # If the topic data is already transfered to some queue
        if topic in self.transfers:
            # Remove it from the transfer list
            del self.transfers[topic]
            self.logger.info('stop transfer | {0}'.format(topic))

    def intransfer(self,topic):
        return topic in self.transfers

    def has_indexed_data(self,topic):
        return self.topic_list[topic].has_indexed_data()
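
A minimal sketch of the registry idea behind topic_list: because it is a SortedDict keyed by topic name, iterating it already yields the names in sorted order, so an explicit sorted() is only needed when filtering on the Topic objects. The names below are hypothetical.

from sortedcontainers import SortedDict

topic_list = SortedDict()
for name in ('sensors/temp', 'alpha', 'beta'):
    topic_list[name] = object()   # stand-in for a Topic instance

print(list(topic_list.keys()))  # ['alpha', 'beta', 'sensors/temp']
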
Example #11
class FederationRemoteSendQueue(object):
    """A drop in replacement for FederationSender"""
    def __init__(self, hs):
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        # We may have multiple federation sender instances, so we need to track
        # their positions separately.
        self._sender_instances = hs.config.worker.federation_shard_config.instances
        self._sender_positions = {}

        # Pending presence map user_id -> UserPresenceState
        self.presence_map = {}  # type: Dict[str, UserPresenceState]

        # Stream position -> list[user_id]
        self.presence_changed = SortedDict(
        )  # type: SortedDict[int, List[str]]

        # Stores the destinations we need to explicitly send presence to about a
        # given user.
        # Stream position -> (user_id, destinations)
        self.presence_destinations = (
            SortedDict())  # type: SortedDict[int, Tuple[str, List[str]]]

        # (destination, key) -> EDU
        self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]

        # stream position -> (destination, key)
        self.keyed_edu_changed = (SortedDict()
                                  )  # type: SortedDict[int, Tuple[str, tuple]]

        self.edus = SortedDict()  # type: SortedDict[int, Edu]

        # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
        self.pos = 1

        # map from stream ID to the time that stream entry was generated, so that we
        # can clear out entries after a while
        self.pos_time = SortedDict()  # type: SortedDict[int, int]

        # EVERYTHING IS SAD. In particular, python only makes new scopes when
        # we make a new function, so we need to make a new function so the inner
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name, queue):
            LaterGauge(
                "synapse_federation_send_queue_%s_size" % (queue_name, ),
                "",
                [],
                lambda: len(queue),
            )

        for queue_name in [
                "presence_map",
                "presence_changed",
                "keyed_edu",
                "keyed_edu_changed",
                "edus",
                "pos_time",
                "presence_destinations",
        ]:
            register(queue_name, getattr(self, queue_name))

        self.clock.looping_call(self._clear_queue, 30 * 1000)

    def _next_pos(self):
        pos = self.pos
        self.pos += 1
        self.pos_time[self.clock.time_msec()] = pos
        return pos

    def _clear_queue(self):
        """Clear the queues for anything older than N minutes"""

        FIVE_MINUTES_AGO = 5 * 60 * 1000
        now = self.clock.time_msec()

        keys = self.pos_time.keys()
        time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
        if not keys[:time]:
            return

        position_to_delete = max(keys[:time])
        for key in keys[:time]:
            del self.pos_time[key]

        self._clear_queue_before_pos(position_to_delete)

    def _clear_queue_before_pos(self, position_to_delete):
        """Clear all the queues from before a given position"""
        with Measure(self.clock, "send_queue._clear"):
            # Delete things out of presence maps
            keys = self.presence_changed.keys()
            i = self.presence_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_changed[key]

            user_ids = {
                user_id
                for uids in self.presence_changed.values() for user_id in uids
            }

            keys = self.presence_destinations.keys()
            i = self.presence_destinations.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_destinations[key]

            user_ids.update(
                user_id for user_id, _ in self.presence_destinations.values())

            to_del = [
                user_id for user_id in self.presence_map
                if user_id not in user_ids
            ]
            for user_id in to_del:
                del self.presence_map[user_id]

            # Delete things out of keyed edus
            keys = self.keyed_edu_changed.keys()
            i = self.keyed_edu_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.keyed_edu_changed[key]

            live_keys = set()
            for edu_key in self.keyed_edu_changed.values():
                live_keys.add(edu_key)

            keys_to_del = [
                edu_key for edu_key in self.keyed_edu
                if edu_key not in live_keys
            ]
            for edu_key in keys_to_del:
                del self.keyed_edu[edu_key]

            # Delete things out of edu map
            keys = self.edus.keys()
            i = self.edus.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.edus[key]

    def notify_new_events(self, current_id):
        """As per FederationSender"""
        # We don't need to replicate this as it gets sent down a different
        # stream.
        pass

    def build_and_send_edu(self, destination, edu_type, content, key=None):
        """As per FederationSender"""
        if destination == self.server_name:
            logger.info("Not sending EDU to ourselves")
            return

        pos = self._next_pos()

        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        if key:
            assert isinstance(key, tuple)
            self.keyed_edu[(destination, key)] = edu
            self.keyed_edu_changed[pos] = (destination, key)
        else:
            self.edus[pos] = edu

        self.notifier.on_new_replication_data()

    def send_read_receipt(self, receipt):
        """As per FederationSender

        Args:
            receipt (synapse.types.ReadReceipt):
        """
        # nothing to do here: the replication listener will handle it.
        return defer.succeed(None)

    def send_presence(self, states):
        """As per FederationSender

        Args:
            states (list(UserPresenceState))
        """
        pos = self._next_pos()

        # We only want to send presence for our own users, so lets always just
        # filter here just in case.
        local_states = list(
            filter(lambda s: self.is_mine_id(s.user_id), states))

        self.presence_map.update(
            {state.user_id: state
             for state in local_states})
        self.presence_changed[pos] = [state.user_id for state in local_states]

        self.notifier.on_new_replication_data()

    def send_presence_to_destinations(self, states, destinations):
        """As per FederationSender

        Args:
            states (list[UserPresenceState])
            destinations (list[str])
        """
        for state in states:
            pos = self._next_pos()
            self.presence_map.update(
                {state.user_id: state
                 for state in states})
            self.presence_destinations[pos] = (state.user_id, destinations)

        self.notifier.on_new_replication_data()

    def send_device_messages(self, destination):
        """As per FederationSender"""
        # We don't need to replicate this as it gets sent down a different
        # stream.

    def get_current_token(self):
        return self.pos - 1

    def federation_ack(self, instance_name, token):
        if self._sender_instances:
            # If we have configured multiple federation sender instances we need
            # to track their positions separately, and only clear the queue up
            # to the token all instances have acked.
            self._sender_positions[instance_name] = token
            token = min(self._sender_positions.values())

        self._clear_queue_before_pos(token)

    async def get_replication_rows(
            self, instance_name: str, from_token: int, to_token: int,
            target_row_count: int
    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
        """Get rows to be sent over federation between the two tokens

        Args:
            instance_name: the name of the current process
            from_token: the previous stream token: the starting point for fetching the
                updates
            to_token: the new stream token: the point to get updates up to
            target_row_count: a target for the number of rows to be returned.

        Returns: a triplet `(updates, new_last_token, limited)`, where:
           * `updates` is a list of `(token, row)` entries.
           * `new_last_token` is the new position in stream.
           * `limited` is whether there are more updates to fetch.
        """
        # TODO: Handle target_row_count.

        # To handle restarts where we wrap around
        if from_token > self.pos:
            from_token = -1

        # list of tuple(int, BaseFederationRow), where the first is the position
        # of the federation stream.
        rows = []  # type: List[Tuple[int, BaseFederationRow]]

        # Fetch changed presence
        i = self.presence_changed.bisect_right(from_token)
        j = self.presence_changed.bisect_right(to_token) + 1
        dest_user_ids = [
            (pos, user_id)
            for pos, user_id_list in self.presence_changed.items()[i:j]
            for user_id in user_id_list
        ]

        for (key, user_id) in dest_user_ids:
            rows.append((key, PresenceRow(state=self.presence_map[user_id])))

        # Fetch presence to send to destinations
        i = self.presence_destinations.bisect_right(from_token)
        j = self.presence_destinations.bisect_right(to_token) + 1

        for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
            rows.append((
                pos,
                PresenceDestinationsRow(state=self.presence_map[user_id],
                                        destinations=list(dests)),
            ))

        # Fetch changes keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)
        j = self.keyed_edu_changed.bisect_right(to_token) + 1
        # We purposefully clobber based on the key here, python dict comprehensions
        # always use the last value, so this will correctly point to the last
        # stream position.
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in keyed_edus.items():
            rows.append((
                pos,
                KeyedEduRow(key=edu_key,
                            edu=self.keyed_edu[(destination, edu_key)]),
            ))

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)
        j = self.edus.bisect_right(to_token) + 1
        edus = self.edus.items()[i:j]

        for (pos, edu) in edus:
            rows.append((pos, EduRow(edu)))

        # Sort rows based on pos
        rows.sort()

        return (
            [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
            to_token,
            False,
        )
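
A minimal sketch of the token-range slicing used throughout get_replication_rows: bisect_right on the exclusive lower bound and on the inclusive upper bound selects the (stream position, payload) pairs in (from_token, to_token]. Positions and payloads here are hypothetical.

from sortedcontainers import SortedDict

edus = SortedDict({1: 'a', 4: 'b', 6: 'c', 9: 'd'})
from_token, to_token = 1, 6

i = edus.bisect_right(from_token)
j = edus.bisect_right(to_token)
print(edus.items()[i:j])  # [(4, 'b'), (6, 'c')]
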
Example #12
class Topics:
    """
A class that manages a collection of `Topic`s.

    """
    def __init__(self):
        self.logger = getLogger('topics')
        self.logger.info('started session')
        self.clear()

    def clear(self):
        self.logger.info('Cleared all topics and received data')
        self.topic_list = SortedDict()
        self.transfers = dict()

    def create(self, topic, source='remote'):
        # Create the topic if it doesn't exist already
        if not topic in self.topic_list:
            self.topic_list[topic] = Topic(topic, source=source)
            self.logger.info('new:topic ' + topic)

    def process(self, topic, payload, options=None):
        # Create the topic if it doesn't exist already
        self.create(topic)

        # Add the new sample
        self.topic_list[topic].new_sample(payload, options)

        # logging
        if options:
            self.logger.debug('new sample | {0} [{1}] {2}'.format(
                topic, options['index'], payload))
        else:
            self.logger.debug('new sample | {0} {1}'.format(topic, payload))

        # If there is an active transfer, transfer received data to the queue
        if topic in self.transfers:
            # If transfer requires indexed data, check there is an index
            if self.transfers[topic][
                    'type'] == 'indexed' and options is not None:
                x = options['index']
                self.transfers[topic]['queue'].put([x, payload])
            # For linear data, provide sample id for x and payload for y
            elif self.transfers[topic]['type'] == 'linear':
                x = self.transfers[topic]['lastindex']
                self.transfers[topic]['queue'].put([x, payload])
                self.transfers[topic]['lastindex'] += 1

    def ls(self, source='remote'):
        if source is None:
            return sorted([t.name for t in self.topic_list.values()])
        else:
            return sorted([
                t.name for t in self.topic_list.values() if t.source == source
            ])

    def samples(self, topic, amount=1):
        if not topic in self.topic_list:
            return None

        if amount == 0 or amount is None:
            return self.topic_list[topic].raw

        return self.topic_list[topic].raw[-amount:]

    def count(self, topic):
        if not topic in self.topic_list:
            return 0

        return len(self.topic_list[topic].raw)

    def exists(self, topic):
        return topic in self.topic_list

    def transfer(self, topic, queue, transfer_type="linear"):
        # If the topic data is not already transfered to some queue
        if not topic in self.transfers:
            self.transfers[topic] = dict()
            self.transfers[topic]['queue'] = queue
            self.transfers[topic]['lastindex'] = 0
            self.transfers[topic]['type'] = transfer_type

            self.logger.info('start transfer | {0}'.format(topic))

            # If there is already existing data under the topic
            if topic in self.topic_list:
                if transfer_type == 'indexed':
                    for key, value in self.topic_list[topic].indexes.items():
                        queue.put([key, value])
                elif transfer_type == 'linear':
                    for item in self.topic_list[topic].raw:
                        queue.put([self.transfers[topic]['lastindex'], item])
                        self.transfers[topic]['lastindex'] += 1

    def untransfer(self, topic):
        # If the topic data is already transfered to some queue
        if topic in self.transfers:
            # Remove it from the transfer list
            del self.transfers[topic]
            self.logger.info('stop transfer | {0}'.format(topic))

    def intransfer(self, topic):
        return topic in self.transfers

    def has_indexed_data(self, topic):
        return self.topic_list[topic].has_indexed_data()
Example #13
class OrderTree(object):
    '''A red-black tree used to store OrderLists in price order

    The exchange will be using the OrderTree to hold bid and ask data (one OrderTree for each side).
    Keeping the information in a red black tree makes it easier/faster to detect a match.
    '''

    def __init__(self):
        self.price_map = SortedDict() # Dictionary containing price : OrderList object
        self.prices = self.price_map.keys()
        self.order_map = {} # Dictionary containing order_id : Order object
        self.volume = 0 # Contains total quantity from all Orders in tree
        self.num_orders = 0 # Contains count of Orders in tree
        self.depth = 0 # Number of different prices in tree (http://en.wikipedia.org/wiki/Order_book_(trading)#Book_depth)

    def __len__(self):
        return len(self.order_map)

    def get_price_list(self, price):
        return self.price_map[price]

    def get_order(self, order_id):
        return self.order_map[order_id]

    def create_price(self, price):
        self.depth += 1 # Add a price depth level to the tree
        new_list = OrderList()
        self.price_map[price] = new_list

    def remove_price(self, price):
        self.depth -= 1 # Remove a price depth level
        del self.price_map[price]

    def price_exists(self, price):
        return price in self.price_map

    def order_exists(self, order):
        return order in self.order_map

    def insert_order(self, quote):
        if self.order_exists(quote['order_id']):
            self.remove_order_by_id(quote['order_id'])
        self.num_orders += 1
        if quote['price'] not in self.price_map:
            self.create_price(quote['price']) # If price not in Price Map, create a node in RBtree
        order = Order(quote, self.price_map[quote['price']]) # Create an order
        self.price_map[order.price].append_order(order) # Add the order to the OrderList in Price Map
        self.order_map[order.order_id] = order
        self.volume += order.quantity

    def update_order(self, order_update):
        order = self.order_map[order_update['order_id']]
        original_quantity = order.quantity
        if order_update['price'] != order.price:
            # Price changed. Remove order and update tree.
            order_list = self.price_map[order.price]
            order_list.remove_order(order)
            if len(order_list) == 0: # If there is nothing else in the OrderList, remove the price from RBtree
                self.remove_price(order.price)
            self.insert_order(order_update)
        else:
            # Quantity changed. Price is the same.
            order.update_quantity(order_update['quantity'], order_update['timestamp'])
        self.volume += order.quantity - original_quantity

    def remove_order_by_id(self, order_id):
        self.num_orders -= 1
        order = self.order_map[order_id]
        self.volume -= order.quantity
        order.order_list.remove_order(order)
        if len(order.order_list) == 0:
            self.remove_price(order.price)
        del self.order_map[order_id]

    def max_price(self):
        if self.depth > 0:
            return self.prices[-1]
        else:
            return None

    def min_price(self):
        if self.depth > 0:
            return self.prices[0]
        else:
            return None

    def max_price_list(self):
        if self.depth > 0:
            return self.get_price_list(self.max_price())
        else:
            return None

    def min_price_list(self):
        if self.depth > 0:
            return self.get_price_list(self.min_price())
        else:
            return None
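
A minimal sketch of the max_price/min_price lookups: prices is the live keys() view of price_map, so its first and last entries always reflect the current extremes, even after later inserts. The prices and order lists here are hypothetical.

from sortedcontainers import SortedDict

price_map = SortedDict({101.5: ['o1'], 99.0: ['o2'], 100.25: ['o3']})
prices = price_map.keys()

print(prices[0], prices[-1])   # 99.0 101.5

price_map[98.5] = ['o4']
print(prices[0])               # 98.5 -- the view tracks the insertion
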
Example #14
    if isPhoto(file) :
      try :
        exif = getExif(os.path.join(subdir, file))
        if not cameraIsValid(exif) :
          continue
        # get focal length and convert from rational data type to float
        focalLength = exif[FOCALLENGTH_TAG][0] / exif[FOCALLENGTH_TAG][1]
        # count every focal length occurrence in dictionary
        if (focalLength in occurences) :
          occurences[focalLength] = occurences[focalLength] + 1
        else:   # find nearest
          index = occurences.bisect(focalLength)
          greater = occurences.iloc[index]
          smaller = occurences.iloc[index - 1]
          nearestFL = greater if (greater - focalLength < focalLength - smaller) else smaller
          occurences[nearestFL] = occurences[nearestFL] + 1
      except (KeyError, TypeError, IndexError) :
        # there is no focal length info in image exif data (Key/Type/IndexError)
        pass

# plot the graph
position = arange(len(focalLengths)) + .5
barh(position, occurences.values(), align='center', color='#FF0000')
yticks(position, occurences.keys())
xlabel('Occurrences')
ylabel('Focal length')
title('Focal length usage analysis')
grid(True)
show()
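
A standalone sketch of the nearest-bucket lookup above, written against the keys() view rather than the older .iloc accessor and with the boundary cases handled; the focal-length table is hypothetical.

from sortedcontainers import SortedDict

occurrences = SortedDict({18.0: 0, 35.0: 0, 50.0: 0, 85.0: 0})

def nearest_key(sd, value):
    keys = sd.keys()
    index = sd.bisect_left(value)
    if index == 0:
        return keys[0]
    if index == len(keys):
        return keys[-1]
    greater, smaller = keys[index], keys[index - 1]
    return greater if greater - value < value - smaller else smaller

occurrences[nearest_key(occurrences, 42.0)] += 1
print(dict(occurrences))  # the 35.0 bucket picks up the count
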
Example #15
class DotMap(MutableMapping):

    def __init__(self, *args, **kwargs):
        self._map = SortedDict()
        if args:
            d = args[0]
            if type(d) is dict:
                for k, v in self.__call_items(d):
                    if type(v) is dict:
                        v = DotMap(v)
                    self._map[k] = v
        if kwargs:
            for k, v in self.__call_items(kwargs):
                self._map[k] = v

    @staticmethod
    def __call_items(obj):
        if hasattr(obj, 'iteritems') and ismethod(getattr(obj, 'iteritems')):
            return obj.iteritems()
        else:
            return obj.items()

    def items(self):
        return self.iteritems()

    def iteritems(self):
        return self.__call_items(self._map)

    def __iter__(self):
        return self._map.__iter__()

    def __setitem__(self, k, v):
        self._map[k] = v

    def __getitem__(self, k):
        if k not in self._map:
            # automatically extend to new DotMap
            self[k] = DotMap()
        return self._map[k]

    def __setattr__(self, k, v):
        if k == '_map':
            super(DotMap, self).__setattr__(k, v)
        else:
            self[k] = v

    def __getattr__(self, k):
        if k == '_map':
            return self._map
        else:
            return self[k]

    def __delattr__(self, key):
        return self._map.__delitem__(key)

    def __contains__(self, k):
        return self._map.__contains__(k)

    def __str__(self):
        items = []
        for k, v in self.__call_items(self._map):
            items.append('{0}={1}'.format(k, repr(v)))
        out = 'DotMap({0})'.format(', '.join(items))
        return out

    def __repr__(self):
        return str(self)

    def to_dict(self):
        d = {}
        for k, v in self.items():
            if type(v) is DotMap:
                v = v.to_dict()
            d[k] = v
        return d

    def pprint(self):
        pprint(self.to_dict())

    # proper dict subclassing
    def values(self):
        return self._map.values()

    @staticmethod
    def parse_other(other):
        if type(other) is DotMap:
            return other._map
        else:
            return other

    def __cmp__(self, other):
        other = DotMap.parse_other(other)
        return self._map.__cmp__(other)

    def __eq__(self, other):
        other = DotMap.parse_other(other)
        if not isinstance(other, dict):
            return False
        return self._map.__eq__(other)

    def __ge__(self, other):
        other = DotMap.parse_other(other)
        return self._map.__ge__(other)

    def __gt__(self, other):
        other = DotMap.parse_other(other)
        return self._map.__gt__(other)

    def __le__(self, other):
        other = DotMap.parse_other(other)
        return self._map.__le__(other)

    def __lt__(self, other):
        other = DotMap.parse_other(other)
        return self._map.__lt__(other)

    def __ne__(self, other):
        other = DotMap.parse_other(other)
        return self._map.__ne__(other)

    def __delitem__(self, key):
        return self._map.__delitem__(key)

    def __len__(self):
        return self._map.__len__()

    def copy(self):
        return self

    def get(self, key, default=None):
        return self._map.get(key, default)

    def has_key(self, key):
        return key in self._map

    def iterkeys(self):
        return self._map.iterkeys()

    def itervalues(self):
        return self._map.itervalues()

    def keys(self):
        return self._map.keys()

    def pop(self, key, default=None):
        return self._map.pop(key, default)

    def setdefault(self, key, default=None):
        return self._map.setdefault(key, default)

    def viewitems(self):
        if version_info.major == 2 and version_info.minor >= 7:
            return self._map.viewitems()
        else:
            return self._map.items()

    def viewkeys(self):
        if version_info.major == 2 and version_info.minor >= 7:
            return self._map.viewkeys()
        else:
            return self._map.keys()

    def viewvalues(self):
        if version_info.major == 2 and version_info.minor >= 7:
            return self._map.viewvalues()
        else:
            return self._map.values()

    @classmethod
    def fromkeys(cls, seq, value=None):
        d = DotMap()
        d._map = SortedDict.fromkeys(seq, value)
        return d
Example #16
class Model(object):
  '''
  The model of a Strandbeest. The model consists of a set of nodes, edges and boundary
  conditions. Each node has a unique name and an x and y position which may change
  whenever the simulation is incremented. Each node introduces two degrees of freedom.
  The edges are specified by the nodes they are connecting. The edges are the push/pull
  rods which connect the nodes with one another. An edge keeps the distance between
  two nodes constant and therefore constrains exactly one degree of freedom in the system.
  '''

  def __init__(self):
    '''
    Constructor
    '''
    self._nodes = SortedDict()
    self._edges = defaultdict(set)

  def addNode(self,name,x,y):
    if not isinstance(name,str  ): raise Exception("The 1st argument must be the node's name as str.")
    if not isinstance(x   ,float): raise Exception("The 2nd argument must be the node's x position as float.")
    if not isinstance(y   ,float): raise Exception("The 3rd argument must be the node's y position as float.")
    if name in self._nodes: raise Exception( 'There already exists a node by the name of "%(name)s"' % locals() )
    self._nodes[name] = x,y
    self.__t = 0.0
    for listener in self.onNodeAddListeners:
      listener(name,x,y)

  def addEdge(self,node1,node2):
    if node1 == node2:
      raise Exception('"node1" cannot be equal to "node2".')
    self._edges[node1].add(node2)
    self._edges[node2].add(node1)
    for listener in self.onEdgeAddListeners:
      listener( min(node1,node2), max(node1,node2) )

  def pos(self,name):
    return self._nodes[name]

  def move(self,name,x,y):
    self._nodes[name] = x,y
    for listener in self.onNodeMoveListeners:
      listener(name,x,y)

  def state(self):
    return fromiter( chain.from_iterable( self._nodes.values() ), float )

  def setState(self,state):
    for i,(x,y) in enumerate( zip(state[::2],state[1::2]) ):
      self.move(self._nodes.keys()[i],x,y)

  @property
  def t(self):
    return self.__t

  def increment(self,dt):
    v = self.v
    t0 = self.__t
    x0 = self.state()
    # https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#The_Runge.E2.80.93Kutta_method
    k0 = v(x0,           t0)
    k1 = v(x0+k0*(dt/2), t0+dt/2)
    k2 = v(x0+k1*(dt/2), t0+dt/2)
    k3 = v(x0+k2*(dt),   t0+dt)
    self.setState( x0 + dt/6 * (k0+k1+k2+k3) )
    self.__t += dt

  def v(self,x,t):
    lhs = zeros( 2*[len(x)] )
    rhs = zeros( len(x) )
    iRows = iter( range( len(x) ) )
    for start,end in self.edges():
      iStart = 2*self._nodes.index(start)
      iEnd   = 2*self._nodes.index(end)
      iRow = next(iRows)
      dx = x[iEnd+0] - x[iStart+0] 
      dy = x[iEnd+1] - x[iStart+1]
      lhs[iRow,iStart+0] = dx; lhs[iRow,iEnd+0] = -dx
      lhs[iRow,iStart+1] = dy; lhs[iRow,iEnd+1] = -dy
      rhs[iRow] = 0
    for bc in self.bcs:
      bc.addEquations(x,t,iRows,lhs,rhs)
    return linalg.solve(lhs,rhs)

  def nodes(self):
    return self._nodes.items()

  def edges(self):
    for node1,neighbors in self._edges.items():
      for node2 in neighbors:
        if node1 < node2:
          yield node1,node2

  bcs = []

  onEdgeAddListeners = set() # <- FIXME should be a multiset
  onNodeAddListeners = set() # <- FIXME should be a multiset
  onNodeMoveListeners = set() # <- FIXME should be a multiset
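
A minimal sketch of the name-to-offset mapping that v() relies on: nodes live in a SortedDict, so a node's slot in the flat state vector is 2 * its index among the sorted keys. The node names and coordinates below are hypothetical.

from sortedcontainers import SortedDict

nodes = SortedDict({'hip': (0.0, 1.0), 'knee': (0.5, 0.5), 'ankle': (1.0, 0.0)})

state = [coord for xy in nodes.values() for coord in xy]
offset = 2 * nodes.index('knee')           # position of 'knee' among sorted keys
print(state[offset], state[offset + 1])    # 0.5 0.5 -- knee's x and y
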
Example #17
class TestFormat(unittest.TestCase):

    '''Tests for formatting functions'''
    @classmethod
    def setUpClass(self):
        self._vocab = SortedDict({
            '1950_1959': [('w1', 1.0), ('w2', 1.0)],
            '1951_1960': [('w3', 1.0), ('w4', 1.0)],
            '1952_1961': [('w5', 1.0), ('w6', 1.0)],
            '1953_1962': [('w7', 1.0), ('w8', 1.0)]
        })
        self._links = SortedDict({
            '1950_1959': {'w1': [('w1', 0.0), ('w2', 1.0)]},
            '1951_1960': {'w3': [('w3', 0.0), ('w4', 1.0)]},
            '1952_1961': {'w5': [('w5', 0.0), ('w6', 1.0)]},
            '1953_1962': {'w7': [('w7', 0.0), ('w8', 1.0)]}
        })
        self._aggVocab = SortedDict({
            '1954': [('w1', 1.0), ('w2', 1.0)],
            '1955': [('w3', 1.0), ('w4', 1.0)],
            '1956': [('w5', 1.0), ('w6', 1.0)],
            '1957': [('w7', 1.0), ('w8', 1.0)]
        })
        self._aggPeriods = SortedDict({
            '1954': ['1950_1959'],
            '1955': ['1951_1960'],
            '1956': ['1952_1961'],
            '1957': ['1953_1962']
        })

    def testGetRangeMiddle(self):
        '''Test finding middle of range works'''
        self.assertEqual(fmt.getRangeMiddle('1951_1960'), 1955,
                         'Middle of 50s decade should be 1955')
        self.assertEqual(fmt.getRangeMiddle('1959_1968', '1962_1971'), 1965,
                         'Middle of 60s decade should be 1965')

    def testYearlyNetwork(self):
        '''Test building of yearly networks'''
        networks = fmt.yearlyNetwork(self._aggPeriods, self._aggVocab,
                                     self._vocab, self._links)

        self.assertEqual(sorted(networks.keys()), list(self._aggVocab.keys()),
                         'A network should be created for each aggregated '
                         'vocabulary')

        self.assertEqual(
            sorted(networks.keys()), list(self._aggPeriods.keys()),
            'A network should be created for each aggregation period')
        for year, net in networks.items():
            self.assertEqual(sorted(net.keys()), sorted(['nodes', 'links']),
                             'Each network should contain "nodes" and "links"'
                             'but %s does not' % year)
            for node in net['nodes']:
                self.assertEqual(sorted(node.keys()),
                                 sorted(['name', 'type', 'count']),
                                 'Each node should contain "name", "type" and '
                                 '"count", but a node on %s does not' % year)
            for link in net['links']:
                self.assertEqual(sorted(link.keys()),
                                 sorted(['source', 'target', 'value']),
                                 'Each link should contain "source", "target" '
                                 'and "value", but a link on %s does not'
                                 % year)

    def testYearTuplesAsDict(self):
        '''Test converting tuple dictionary to nested dictionary'''
        dicts = fmt.yearTuplesAsDict(self._aggVocab)
        self.assertEqual(sorted(dicts.keys()), list(self._aggVocab.keys()),
                         'A dictionary should be created for each aggregated '
                         'vocabulary')
        for year, d in dicts.items():
            self.assertEqual(len(d), len(self._aggVocab[year]),
                             'Dict should have same number of items as '
                             'aggregated vocabulary but %s does not' % year)

    def testWordLocationAsDict(self):
        '''Test creating word-location dictionary'''
        word = 'w1'
        loc = (0, 1)
        d = fmt.wordLocationAsDict(word, loc)
        self.assertIsInstance(d, dict, 'Should be a dictionary')
        self.assertEqual(sorted(d.keys()),
                         sorted(['word', 'x', 'y']),
                         'Should contain "word", "x" and "y"')
Example #18
0
class FederationRemoteSendQueue(object):
    """A drop-in replacement for FederationSender"""

    def __init__(self, hs):
        self.server_name = hs.hostname
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self.is_mine_id = hs.is_mine_id

        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
        self.presence_changed = SortedDict()  # Stream position -> list[user_id]

        # Stores the destinations we need to explicitly send presence to about a
        # given user.
        # Stream position -> (user_id, destinations)
        self.presence_destinations = SortedDict()

        self.keyed_edu = {}  # (destination, key) -> EDU
        self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)

        self.edus = SortedDict()  # stream position -> Edu

        self.device_messages = SortedDict()  # stream position -> destination

        self.pos = 1
        self.pos_time = SortedDict()

        # EVERYTHING IS SAD. In particular, python only makes new scopes when
        # we make a new function, so we need to make a new function so the inner
        # lambda binds to the queue rather than to the name of the queue which
        # changes. ARGH.
        def register(name, queue):
            LaterGauge("synapse_federation_send_queue_%s_size" % (name,),
                       "", [], lambda: len(queue))

        for queue_name in [
            "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
            "edus", "device_messages", "pos_time", "presence_destinations",
        ]:
            register(queue_name, getattr(self, queue_name))

        self.clock.looping_call(self._clear_queue, 30 * 1000)

    def _next_pos(self):
        pos = self.pos
        self.pos += 1
        self.pos_time[self.clock.time_msec()] = pos
        return pos

    def _clear_queue(self):
        """Clear the queues for anything older than N minutes"""

        FIVE_MINUTES_AGO = 5 * 60 * 1000
        now = self.clock.time_msec()

        keys = self.pos_time.keys()
        time = self.pos_time.bisect_left(now - FIVE_MINUTES_AGO)
        if not keys[:time]:
            return

        position_to_delete = max(keys[:time])
        for key in keys[:time]:
            del self.pos_time[key]

        self._clear_queue_before_pos(position_to_delete)

    def _clear_queue_before_pos(self, position_to_delete):
        """Clear all the queues from before a given position"""
        with Measure(self.clock, "send_queue._clear"):
            # Delete things out of presence maps
            keys = self.presence_changed.keys()
            i = self.presence_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_changed[key]

            user_ids = set(
                user_id
                for uids in self.presence_changed.values()
                for user_id in uids
            )

            keys = self.presence_destinations.keys()
            i = self.presence_destinations.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.presence_destinations[key]

            user_ids.update(
                user_id for user_id, _ in self.presence_destinations.values()
            )

            to_del = [
                user_id for user_id in self.presence_map if user_id not in user_ids
            ]
            for user_id in to_del:
                del self.presence_map[user_id]

            # Delete things out of keyed edus
            keys = self.keyed_edu_changed.keys()
            i = self.keyed_edu_changed.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.keyed_edu_changed[key]

            live_keys = set()
            for edu_key in self.keyed_edu_changed.values():
                live_keys.add(edu_key)

            to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
            for edu_key in to_del:
                del self.keyed_edu[edu_key]

            # Delete things out of edu map
            keys = self.edus.keys()
            i = self.edus.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.edus[key]

            # Delete things out of device map
            keys = self.device_messages.keys()
            i = self.device_messages.bisect_left(position_to_delete)
            for key in keys[:i]:
                del self.device_messages[key]

    def notify_new_events(self, current_id):
        """As per FederationSender"""
        # We don't need to replicate this as it gets sent down a different
        # stream.
        pass

    def build_and_send_edu(self, destination, edu_type, content, key=None):
        """As per FederationSender"""
        if destination == self.server_name:
            logger.info("Not sending EDU to ourselves")
            return

        pos = self._next_pos()

        edu = Edu(
            origin=self.server_name,
            destination=destination,
            edu_type=edu_type,
            content=content,
        )

        if key:
            assert isinstance(key, tuple)
            self.keyed_edu[(destination, key)] = edu
            self.keyed_edu_changed[pos] = (destination, key)
        else:
            self.edus[pos] = edu

        self.notifier.on_new_replication_data()

    def send_read_receipt(self, receipt):
        """As per FederationSender

        Args:
            receipt (synapse.types.ReadReceipt):
        """
        # nothing to do here: the replication listener will handle it.
        pass

    def send_presence(self, states):
        """As per FederationSender

        Args:
            states (list(UserPresenceState))
        """
        pos = self._next_pos()

        # We only want to send presence for our own users, so lets always just
        # filter here just in case.
        local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))

        self.presence_map.update({state.user_id: state for state in local_states})
        self.presence_changed[pos] = [state.user_id for state in local_states]

        self.notifier.on_new_replication_data()

    def send_presence_to_destinations(self, states, destinations):
        """As per FederationSender

        Args:
            states (list[UserPresenceState])
            destinations (list[str])
        """
        for state in states:
            pos = self._next_pos()
            self.presence_map.update({state.user_id: state for state in states})
            self.presence_destinations[pos] = (state.user_id, destinations)

        self.notifier.on_new_replication_data()

    def send_device_messages(self, destination):
        """As per FederationSender"""
        pos = self._next_pos()
        self.device_messages[pos] = destination
        self.notifier.on_new_replication_data()

    def get_current_token(self):
        return self.pos - 1

    def federation_ack(self, token):
        self._clear_queue_before_pos(token)

    def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
        """Get rows to be sent over federation between the two tokens

        Args:
            from_token (int)
            to_token(int)
            limit (int)
            federation_ack (int): Optional. The position the worker has
                explicitly acknowledged it has handled. Allows us to drop
                data from before that point
        """
        # TODO: Handle limit.

        # To handle restarts where we wrap around
        if from_token > self.pos:
            from_token = -1

        # list of tuple(int, BaseFederationRow), where the first is the position
        # of the federation stream.
        rows = []

        # There should be only one reader, so let's delete everything it has
        # acknowledged it has seen.
        if federation_ack:
            self._clear_queue_before_pos(federation_ack)

        # Fetch changed presence
        i = self.presence_changed.bisect_right(from_token)
        j = self.presence_changed.bisect_right(to_token) + 1
        dest_user_ids = [
            (pos, user_id)
            for pos, user_id_list in self.presence_changed.items()[i:j]
            for user_id in user_id_list
        ]

        for (key, user_id) in dest_user_ids:
            rows.append((key, PresenceRow(
                state=self.presence_map[user_id],
            )))

        # Fetch presence to send to destinations
        i = self.presence_destinations.bisect_right(from_token)
        j = self.presence_destinations.bisect_right(to_token) + 1

        for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
            rows.append((pos, PresenceDestinationsRow(
                state=self.presence_map[user_id],
                destinations=list(dests),
            )))

        # Fetch changes keyed edus
        i = self.keyed_edu_changed.bisect_right(from_token)
        j = self.keyed_edu_changed.bisect_right(to_token) + 1
        # We purposefully clobber based on the key here, python dict comprehensions
        # always use the last value, so this will correctly point to the last
        # stream position.
        keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}

        for ((destination, edu_key), pos) in iteritems(keyed_edus):
            rows.append((pos, KeyedEduRow(
                key=edu_key,
                edu=self.keyed_edu[(destination, edu_key)],
            )))

        # Fetch changed edus
        i = self.edus.bisect_right(from_token)
        j = self.edus.bisect_right(to_token) + 1
        edus = self.edus.items()[i:j]

        for (pos, edu) in edus:
            rows.append((pos, EduRow(edu)))

        # Fetch changed device messages
        i = self.device_messages.bisect_right(from_token)
        j = self.device_messages.bisect_right(to_token) + 1
        device_messages = {v: k for k, v in self.device_messages.items()[i:j]}

        for (destination, pos) in iteritems(device_messages):
            rows.append((pos, DeviceRow(
                destination=destination,
            )))

        # Sort rows based on pos
        rows.sort()

        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
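The "EVERYTHING IS SAD" comment in the constructor above refers to Python's late binding of closures: a lambda defined inside a loop captures the loop variable itself, not its value at that iteration, which is why the `register` helper exists. A minimal standalone sketch of the pitfall and the usual fix (plain lists and prints, not the Synapse metrics API):

# Late binding: every lambda ends up seeing the final value of `q`.
queues = {"edus": [1, 2], "device_messages": [1, 2, 3]}

broken = [lambda: len(q) for q in queues.values()]
print([f() for f in broken])        # [3, 3] - both closures see the last queue

# Fix: introduce a new scope per iteration so each closure binds its own queue.
def make_gauge(queue):
    return lambda: len(queue)

fixed = [make_gauge(q) for q in queues.values()]
print([f() for f in fixed])         # [2, 3]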
Example #19
0
while 1:

  choice = input('=========CHOOSE AN OPERATION=========\nEnter the operation you want to perform : \n 1. Add \n 2. Query \n 3. Delete \n 4. Exit \n ======================================\n')
  choice = choice.title()

  if choice == '2' or choice == 'Query':

    print('\n ============= QUERY ============== \n')

    print('Your Phonebook : \n', pb.head())

    # Input keyword from user
    keyword = str(input('Input the keyword you want to search : '))

    All_Con = list(phonebook.keys())

    res = []
    for i in All_Con:
      if i.find(keyword) != -1:
        res.append(i)

    if len(res) == 0:
      print('No contacts matching this key exists! :/', '\n')
    else:
      print(f'All keys matching the keyword "{keyword}"\n', res)

      for i in res:
        print(i)
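The loop above finds matches by scanning every key for a substring, which is the general solution. When only prefix matches are needed, a sorted phonebook can answer the query without a full scan. A hedged sketch, assuming the phonebook is a SortedDict keyed by contact name; the `contacts_with_prefix` helper is illustrative and not part of the snippet above.

from sortedcontainers import SortedDict

phonebook = SortedDict({'Alice': '111', 'Alicia': '222', 'Bob': '333'})

def contacts_with_prefix(book, prefix):
    matches = []
    for name in book.irange(minimum=prefix):
        if not name.startswith(prefix):
            break  # keys are sorted, so no later key can match the prefix
        matches.append(name)
    return matches

print(contacts_with_prefix(phonebook, 'Ali'))  # ['Alice', 'Alicia']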
Example #20
0
class ImpulseDiGraph(ImpulseGraph):
    """Base class for directed impulse graphs.

    The ImpulseDiGraph class allows any hashable object as a node
    and can associate key/value attribute pairs with each directed edge.

    Each edge must have one integer, timestamp.

    Self-loops are allowed.
    Multiple edges between two nodes are allowed.

    Parameters
    ----------
    attr : keyword arguments, optional (default= no attributes)
        Attributes to add to graph as key=value pairs.

    Examples
    --------
    Create an empty graph structure (a "null impulse graph") with no nodes and
    no edges.

    >>> G = dnx.ImpulseDiGraph()

    G can be grown in several ways.

    **Nodes:**

    Add one node at a time:

    >>> G.add_node(1)

    Add the nodes from any container (a list, dict, set or
    even the lines from a file or the nodes from another graph).

    Add the nodes from any container (a list, dict, set)

    >>> G.add_nodes_from([2, 3])
    >>> G.add_nodes_from(range(100, 110))

    **Edges:**

    G can also be grown by adding edges. This can be considered
    the primary way to grow G, since nodes with no edge will not
    appear in G in most cases. See ``G.to_snapshot()``.

    Add one edge, with timestamp of 10.

    >>> G.add_edge(1, 2, 10)

    a list of edges,

    >>> G.add_edges_from([(1, 2, 10), (1, 3, 11)])

    If some edges connect nodes not yet in the graph, the nodes
    are added automatically. There are no errors when adding
    nodes or edges that already exist.

    **Attributes:**

    Each impulse graph, node, and edge can hold key/value attribute pairs
    in an associated attribute dictionary (the keys must be hashable).
    By default these are empty, but can be added or changed using
    add_edge, add_node.

    Keep in mind that the edge timestamp is not an attribute of the edge.

    >>> G = dnx.ImpulseDiGraph(day="Friday")
    >>> G.graph
    {'day': 'Friday'}

    Add node attributes using add_node(), add_nodes_from()

    >>> G.add_node(1, time='5pm')
    >>> G.add_nodes_from([3], time='2pm')

    Add edge attributes using add_edge(), add_edges_from().

    >>> G.add_edge(1, 2, 10, weight=4.7 )
    >>> G.add_edges_from([(3, 4, 11), (4, 5, 33)], color='red')

    **Shortcuts:**

    Here are a couple examples of available shortcuts:

    >>> 1 in G  # check if node in impulse graph during any timestamp
    True
    >>> len(G)  # number of nodes in the entire impulse graph
    5

    **Subclasses (Advanced):**
    Edges in impulse graphs are represented by tuples kept in a SortedDict
    (http://www.grantjenks.com/docs/sortedcontainers/) keyed by timestamp.

    The Graph class uses a dict-of-dict-of-dict data structure.
    The outer dict (node_dict) holds adjacency information keyed by nodes.
    The next dict (adjlist_dict) represents the adjacency information and holds
    edge data keyed by interval objects. The inner dict (edge_attr_dict) represents
    the edge data and holds edge attribute values keyed by attribute names.
    """
    def __init__(self, **attr):
        """Initialize an impulse graph with edges, name, or graph attributes.

        Parameters
        ----------
        attr : keyword arguments, optional (default= no attributes)
            Attributes to add to graph as key=value pairs.

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G = dnx.ImpulseDiGraph(name='my graph')
        >>> G.graph
        {'name': 'my graph'}
        """

        self.tree = SortedDict()
        self.graph = {}  # dictionary for graph attributes
        self._node = {}
        self._pred = {}  # out
        self._succ = {}  # in
        self.edgeid = 0

        self.graph.update(attr)

    def add_edge(self, u, v, t, **attr):
        """Add an edge between u and v, at t.

        The nodes u and v will be automatically added if they are
        not already in the impulse graph.

        Edge attributes can be specified with keywords or by directly
        accessing the edge's attribute dictionary. See examples below.

        Parameters
        ----------
        u, v : nodes
            Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
        t : timestamp
            Timestamps can be, for example, strings or numbers.
            Timestamps must be hashable (and not None) Python objects.
        attr : keyword arguments, optional
            Edge data (or labels or objects) can be assigned using
            keyword arguments.

        See Also
        --------
        add_edges_from : add a collection of edges

        Notes
        -----
        Adding an edge that already exists updates the edge data.

        Timestamps must be the same type across all edges in the impulse graph.
        Also, to create snapshots, timestamps must be integers.

        Many NetworkX algorithms designed for weighted graphs use
        an edge attribute (by default `weight`) to hold a numerical value.

        Examples
        --------
        The following all add the edge e=(1, 2, 10) to graph G:

        >>> G = dnx.ImpulseDiGraph()
        >>> e = (1, 2, 10)
        >>> G.add_edge(1, 2, 10)           # explicit two-node form with timestamp
        >>> G.add_edge(*e)             # single edge as tuple of two nodes and timestamp
        >>> G.add_edges_from([(1, 2, 10)])  # add edges from iterable container

        Associate data to edges using keywords:

        >>> G.add_edge(1, 2, 10, weight=3)
        >>> G.add_edge(1, 3, 9, weight=7, capacity=15, length=342.7)
        """

        self.tree.setdefault(t, set()).add((u, v))

        self._node.setdefault(u, {})
        self._node.setdefault(v, {})
        self._pred.setdefault(u, {}).setdefault(v, {})[(u, v, t)] = attr
        self._succ.setdefault(v, {}).setdefault(u, {})[(u, v, t)] = attr

    def add_edges_from(self, ebunch_to_add, **attr):
        """Add all the edges in ebunch_to_add.

        Parameters
        ----------
        ebunch_to_add : container of edges
            Each edge given in the container will be added to the
            impulse graph. The edges must be given as 3-tuples (u, v, t).
            Timestamp must be orderable and the same type across all edges.
        attr : keyword arguments, optional
            Edge data (or labels or objects) can be assigned using
            keyword arguments.

        See Also
        --------
        add_edge : add a single edge

        Notes
        -----
        Adding the same edge (with the same timestamp) twice has no effect
        but any edge data will be updated when each duplicate edge is added.

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11)]) # using a list of edge tuples

        Associate data to edges

        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11)], weight=3)
        >>> G.add_edges_from([(3, 4, 19), (1, 4, 3)], label='WN2898')
        """

        for e in ebunch_to_add:
            if len(e) != 3:
                raise NetworkXError(
                    "Edge tuple {0} must be a 3-tuple.".format(e))
            self.add_edge(e[0], e[1], e[2], **attr)

    def has_edge(self, u, v, begin=None, end=None, inclusive=(True, True)):
        """Return True if there exists an edge between u and v
        in the impulse graph, during the given interval.

        Parameters
        ----------
        u, v : nodes
            Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
        begin : int or float, optional (default= beginning of the entire impulse graph)
        end : int or float, optional (default= end of the entire impulse graph)
            Must be bigger than or equal to begin.
        inclusive: 2-tuple boolean that determines inclusivity of begin and end

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11)])
        >>> G.has_edge(1, 2)
        True
        >>> G.has_edge(1, 2, begin=2)
        True
        >>> G.has_edge(2, 4, begin=12)
        False
        """

        if u not in self._pred or v not in self._pred[u]:
            return False

        if begin is None and end is None:
            return True

        if begin is not None and end is not None and begin > end:
            raise NetworkXError(
                "ImpulseDiGraph: interval end must be bigger than or equal to begin: "
                "begin: {}, end: {}.".format(begin, end))

        for iv in self._pred[u][v]:
            if self.__in_interval(iv[2], begin, end, inclusive=inclusive):
                return True
        return False

    def edges(self,
              u=None,
              v=None,
              begin=None,
              end=None,
              inclusive=(True, True),
              data=False,
              default=None):
        """Returns a list of tuples of the ImpulseDiGraph edges.

        All edges which are present within the given interval.

        All parameters are optional. `u` and `v` can be thought of as constraints.
        If no node is defined, all edges within the interval are returned.
        If one node is defined, all edges which have that node as one end,
        will be returned, and finally if both nodes are defined then all
        edges between the two nodes are returned.

        Parameters
        ----------
        u, v : nodes, optional (default=None)
            Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
            If the node does not exist in the graph, a key error is raised.
        begin: int or float, optional  (default= beginning of the entire impulse graph)
        end: int or float, optional  (default= end of the entire impulse graph)
            Must be bigger than or equal to begin.
        inclusive: 2-tuple boolean that determines inclusivity of begin and end
        data : string or bool, optional (default=False)
            If True, return 2-tuple (Edge Tuple, dict of attributes).
            If False, return just the Edge Tuples.
            If string (name of the attribute), return 2-tuple (Edge Tuple, attribute value).
        default : value, optional (default=None)
            Default Value to be used for edges that don't have the requested attribute.
            Only relevant if `data` is a string (name of an attribute).

        Returns
        -------
        List of Edge Tuples
            An edge tuple has the following format: (u, v, timestamp)

            When called, if `data` is False, a list of edge tuples.
            If `data` is True, a list of 2-tuples: (Edge Tuple, dict of attribute(s) with values),
            If `data` is a string, a list of 2-tuples (Edge Tuple, attribute value).

        Examples
        --------
        To get a list of all edges:

        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11), (6, 4, 19), (2, 4, 15)])
        >>> G.edges()
        [(1, 2, 10), (2, 4, 11), (2, 4, 15), (6, 4, 19)]

        To get edges which appear in a specific interval:

        >>> G.edges(begin=10)
        [(1, 2, 10), (2, 4, 11), (2, 4, 15), (6, 4, 19)]
        >>> G.edges(end=11)
        [(1, 2, 10), (2, 4, 11)]
        >>> G.edges(begin=11, end=15)
        [(2, 4, 11), (2, 4, 15)]

        To get edges with either of the two nodes being defined:

        >>> G.edges(u=2)
        [(2, 4, 11), (2, 4, 15)]
        >>> G.edges(u=2, begin=11)
        [(2, 4, 11), (2, 4, 15)]
        >>> G.edges(u=2, v=4, end=11)
        [(2, 4, 11)]
        >>> G.edges(u=1, v=6)
        []

        To get a list of edges with data:

        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edge(1, 3, 4, weight=8, height=18)
        >>> G.add_edge(1, 2, 10, weight=10)
        >>> G.add_edge(2, 6, 10)
        >>> G.edges(data="weight")
        [((1, 3, 4), 8), ((1, 2, 10), 10), ((2, 6, 10), None)]
        >>> G.edges(data="weight", default=5)
        [((1, 3, 4), 8), ((1, 2, 10), 10), ((2, 6, 10), 5)]
        >>> G.edges(data=True)
        [((1, 3, 4), {'weight': 8, 'height': 18}), ((1, 2, 10), {'weight': 10}), ((2, 6, 10), {})]
        >>> G.edges(u=1, begin=2, end=9, data="weight")
        [((1, 3, 4), 8)]
        """

        if begin is None:
            inclusive = (True, inclusive[1])
        if end is None:
            inclusive = (inclusive[0], True)

        if u is None and v is None:
            if begin is not None and end is not None and begin > end:
                raise NetworkXError(
                    "ImpulseDiGraph: interval end must be bigger than or equal to begin: "
                    "begin: {}, end: {}.".format(begin, end))
            iedges = [iv for iv in self.__search_tree(begin, end, inclusive)]

        else:
            # Node filtering
            if u is not None and v is not None:
                if u not in self._pred:
                    return []
                if v not in self._pred[u]:
                    return []
                iedges = self._pred[u][v]

            elif u is not None:
                if u not in self._pred:
                    return []
                iedges = [iv for v in self._pred[u] for iv in self._pred[u][v]]
            else:
                if v not in self._succ:
                    return []
                iedges = [iv for u in self._succ[v] for iv in self._succ[v][u]]

            # Interval filtering
            if begin is not None and end is not None and begin > end:
                raise NetworkXError(
                    "ImpulseDiGraph: interval end must be bigger than or equal to begin: "
                    "begin: {}, end: {}.".format(begin, end))
            iedges = [
                iv for iv in iedges
                if self.__in_interval(iv[2], begin, end, inclusive=inclusive)
            ]

        if data is False:
            return [edge for edge in iedges]

        if data is True:
            return [(edge, self._pred[edge[0]][edge[1]][edge])
                    for edge in iedges]
        return [(edge, self._pred[edge[0]][edge[1]][edge][data]) if data
                in self._pred[edge[0]][edge[1]][edge] else (edge, default)
                for edge in iedges]

    def remove_edge(self, u, v, begin=None, end=None, inclusive=(True, True)):
        """Remove the edge between u and v in the impulse graph,
        during the given interval.

        Quiet if the specified edge is not present.

        Parameters
        ----------
        u, v : nodes
            Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
        begin : int or float, optional (default= beginning of the entire impulse graph)
        end : int or float, optional (default= end of the entire impulse graph)
            Must be bigger than or equal to begin.
        inclusive: 2-tuple boolean that determines inclusivity of begin and end

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11), (6, 4, 9), (1, 2, 15)])
        >>> G.remove_edge(1, 2)
        >>> G.has_edge(1, 2)
        False

        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11), (6, 4, 9), (1, 2, 15)])
        >>> G.remove_edge(1, 2, begin=2, end=11)
        >>> G.has_edge(1, 2, begin=2, end=11)
        False
        >>> G.has_edge(1, 2)
        True
        """

        if u not in self._pred or v not in self._pred[u]:
            return

        iedges_to_remove = []

        # remove every edge between u and v
        if begin is None and end is None:
            for iv in self._pred[u][v]:
                iedges_to_remove.append(iv)

        else:
            for iv in self._pred[u][v]:
                if self.__in_interval(iv[2], begin, end):
                    iedges_to_remove.append(iv)

        # removing found iedges
        for edge in iedges_to_remove:
            self.__remove_iedge(edge)

        # clean up empty dictionaries
        if len(self._pred[u][v]) == 0:
            self._pred[u].pop(v, None)
        if len(self._succ[v][u]) == 0:
            self._succ[v].pop(u, None)
        if len(self._pred[u]) == 0:
            self._pred.pop(u, None)
        if len(self._succ[v]) == 0:
            self._succ.pop(v, None)

    def degree(self,
               node=None,
               begin=None,
               end=None,
               delta=False,
               inclusive=(True, True)):
        """Return the sum of in and out degree of a specified node between time begin and end.

        Parameters
        ----------
        node : Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
        begin : int or float, optional (default= beginning of the entire impulse graph)
            Inclusive beginning time of the edge appearing in the impulse graph.
        end : int or float, optional (default= end of the entire impulse graph)
            Non-inclusive ending time of the edge appearing in the impulse graph.
        delta : boolean, optional (default= False)
            Returns list of 2-tuples, first element is the timestamp, second is the change in degree at that timestamp.
        inclusive : 2-tuple boolean that determines inclusivity of begin and end

        Returns
        -------
        Integer value of degree of specified node.

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edge(1, 2, 3)
        >>> G.add_edge(2, 3, 8)
        >>> G.degree(2)
        2
        >>> G.degree(2, 4)
        1
        >>> G.degree(2, end=8)
        2
        >>> G.degree()
        1.33333
        >>> G.degree(2, delta=True)
        [(3, 1), (8, 1)]
        """
        # no specified node, return mean degree
        if node == None:
            n = 0
            l = 0
            for node in self.nodes(begin=begin, end=end, inclusive=inclusive):
                n += 1
                l += self.degree(node,
                                 begin=begin,
                                 end=end,
                                 inclusive=inclusive)
            return l / n

        # specified node, no degree_change, return degree
        if delta == False:
            return len(self.edges(u=node, begin=begin, end=end, inclusive=inclusive)) + \
                   len(self.edges(v=node, begin=begin, end=end, inclusive=inclusive))

        # delta == True, return list of changes
        if begin == None:
            begin = list(self.tree.keys())[0]
        if end == None:
            end = list(self.tree.keys())[-1]

        d = {}
        output = []

        # for each edge determine if the begin and/or end value is in specified time period
        for edge in self.edges(u=node,
                               begin=begin,
                               end=end,
                               inclusive=(True, True)):
            d.setdefault(edge[2], []).append((edge[0], edge[1]))
        for edge in self.edges(v=node,
                               begin=begin,
                               end=end,
                               inclusive=(True, True)):
            d.setdefault(edge[2], []).append((edge[0], edge[1]))

        # for each time in Dict add to output list the len of each value
        for time in d:
            output.append((time, len(d[time])))

        return sorted(output)

    def in_degree(self,
                  node=None,
                  begin=None,
                  end=None,
                  delta=False,
                  inclusive=(True, True)):
        """Return the in-degree of a specified node between time begin and end.

        Parameters
        ----------
        node : Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
        begin : int or float, optional (default= beginning of the entire impulse graph)
            Inclusive beginning time of the edge appearing in the impulse graph.
        end : int or float, optional (default= end of the entire impulse graph)
            Non-inclusive ending time of the edge appearing in the impulse graph.
        delta : boolean, optional (default= False)
            Returns list of 2-tuples, first element is the timestamp, second is the change in in-degree at that timestamp.
        inclusive : 2-tuple boolean that determines inclusivity of begin and end

        Returns
        -------
        Integer value of in-degree of specified node.

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edge(1, 2, 3)
        >>> G.add_edge(2, 3, 8)
        >>> G.in_degree(2)
        1
        >>> G.in_degree(2, 4)
        0
        >>> G.in_degree(2, end=8)
        1
        >>> G.in_degree()
        0.66666
        >>> G.in_degree(2, delta=True)
        [(3, 1)]
        """
        # no specified node, return mean degree
        if node == None:
            n = 0
            l = 0
            for node in self.nodes(begin=begin, end=end, inclusive=inclusive):
                n += 1
                l += self.in_degree(node,
                                    begin=begin,
                                    end=end,
                                    inclusive=inclusive)
            return l / n

        # specified node, no degree_change, return degree
        if delta == False:
            return len(
                self.edges(v=node, begin=begin, end=end, inclusive=inclusive))

        # delta == True, return list of changes
        if begin == None:
            begin = list(self.tree.keys())[0]
        if end == None:
            end = list(self.tree.keys())[-1]

        d = {}
        output = []

        # for each edge determine if the begin and/or end value is in specified time period
        for edge in self.edges(v=node,
                               begin=begin,
                               end=end,
                               inclusive=(True, True)):
            d.setdefault(edge[2], []).append((edge[0], edge[1]))

        # for each time in Dict add to output list the len of each value
        for time in d:
            output.append((time, len(d[time])))

        return output

    def out_degree(self,
                   node=None,
                   begin=None,
                   end=None,
                   delta=False,
                   inclusive=(True, True)):
        """Return the out-degree of a specified node between time begin and end.

        Parameters
        ----------
        node : Nodes can be, for example, strings or numbers.
            Nodes must be hashable (and not None) Python objects.
        begin : int or float, optional (default= beginning of the entire impulse graph)
            Inclusive beginning time of the edge appearing in the impulse graph.
        end : int or float, optional (default= end of the entire impulse graph)
            Non-inclusive ending time of the edge appearing in the impulse graph.
        delta : boolean, optional (default= False)
            Returns list of 2-tuples, first element is the timestamp, second is the change in out-degree at that timestamp.
        inclusive : 2-tuple boolean that determines inclusivity of begin and end

        Returns
        -------
        Integer value of out-degree of specified node.

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edge(1, 2, 3)
        >>> G.add_edge(2, 3, 8)
        >>> G.out_degree(2)
        1
        >>> G.out_degree(2, 2)
        1
        >>> G.out_degree(2, end=8)
        1
        >>> G.out_degree()
        0.66666
        >>> G.out_degree(2, delta=True)
        [(8, 1)]
        """
        # no specified node, return mean degree
        if node == None:
            n = 0
            l = 0
            for node in self.nodes(begin=begin, end=end, inclusive=inclusive):
                n += 1
                l += self.out_degree(node,
                                     begin=begin,
                                     end=end,
                                     inclusive=inclusive)
            return l / n

        # specified node, no degree_change, return degree
        if delta == False:
            return len(
                self.edges(u=node, begin=begin, end=end, inclusive=inclusive))

        # delta == True, return list of changes
        if begin == None:
            begin = list(self.tree.keys())[0]
        if end == None:
            end = list(self.tree.keys())[-1]

        d = {}
        output = []

        # for each edge determine if the begin and/or end value is in specified time period
        for edge in self.edges(u=node,
                               begin=begin,
                               end=end,
                               inclusive=(True, True)):
            d.setdefault(edge[2], []).append((edge[0], edge[1]))

        # for each time in Dict add to output list the len of each value
        for time in d:
            output.append((time, len(d[time])))

        return output

    def to_networkx_graph(self,
                          begin=None,
                          end=None,
                          inclusive=(True, False),
                          multigraph=False,
                          edge_data=False,
                          edge_timestamp_data=False,
                          node_data=False):
        """Return a networkx Graph or MultiGraph which includes all the nodes and
        edges which have timestamps within the given interval.

        Wrapper function for ImpulseGraph.to_subgraph. Refer to ImpulseGraph.to_subgraph for full description.
        """
        return self.to_subgraph(begin=begin,
                                end=end,
                                inclusive=inclusive,
                                multigraph=multigraph,
                                edge_data=edge_data,
                                edge_timestamp_data=edge_timestamp_data,
                                node_data=node_data)

    def to_subgraph(self,
                    begin,
                    end,
                    inclusive=(True, False),
                    multigraph=False,
                    edge_data=False,
                    edge_timestamp_data=False,
                    node_data=False):
        """Return a networkx Graph or MultiGraph which includes all the nodes and
        edges which have timestamps within the given interval.

        Parameters
        ----------
        begin: int or float
        end: int or float
            Must be bigger than or equal to begin.
        inclusive: 2-tuple boolean that determines inclusivity of begin and end
        multigraph: bool, optional (default= False)
            If True, a networkx MultiGraph will be returned. If False, networkx Graph.
        edge_data: bool, optional (default= False)
            If True, edges will keep their attributes.
        edge_timestamp_data: bool, optional (default= False)
            If True, each edge's attribute will also include its timestamp data.
            If `edge_data= True` and there already exist edge attributes named timestamp
            it will be overwritten.
        node_data : bool, optional (default= False)
            if True, each node's attributes will be included.

        See Also
        --------
        to_snapshots : divide the impulse graph to snapshots

        Notes
        -----
        If multigraph= False, and edge_data=True or edge_interval_data=True,
        in case there are multiple edges, only one will show with one of the edge's attributes.

        Note: nodes with no edges will not appear in any subgraph.

        Examples
        --------
        >>> G = dnx.ImpulseDiGraph()
        >>> G.add_edges_from([(1, 2, 10), (2, 4, 11), (6, 4, 19), (2, 4, 15)])
        >>> H = G.to_subgraph(4, 12)
        >>> type(H)
        <class 'networkx.classes.digraph.DiGraph'>
        >>> list(H.edges(data=True))
        [(1, 2, {}), (2, 4, {})]

        >>> H = G.to_subgraph(10, 12, edge_timestamp_data=True)
        >>> type(H)
        <class 'networkx.classes.digraph.DiGraph'>
        >>> list(H.edges(data=True))
        [(1, 2, {'timestamp': 10}), (2, 4, {'timestamp': 11})]

        >>> M = G.to_subgraph(4, 12, multigraph=True, edge_timestamp_data=True)
        >>> type(M)
        <class 'networkx.classes.multidigraph.MultiDiGraph'>
        >>> list(M.edges(data=True))
        [(1, 2, {'timestamp': 10}), (2, 4, {'timestamp': 11})]
        """
        iedges = self.__search_tree(begin, end, inclusive=inclusive)

        if multigraph:
            G = MultiDiGraph()
        else:
            G = DiGraph()

        if edge_data and edge_timestamp_data:
            G.add_edges_from((iedge[0], iedge[1],
                              dict(self._pred[iedge[0]][iedge[1]][iedge],
                                   timestamp=iedge[2])) for iedge in iedges)
        elif edge_data:
            G.add_edges_from(
                (iedge[0], iedge[1], self._pred[iedge[0]][iedge[1]][iedge])
                for iedge in iedges)
        elif edge_timestamp_data:
            G.add_edges_from((iedge[0], iedge[1], {
                'timestamp': iedge[2]
            }) for iedge in iedges)
        else:

            G.add_edges_from((iedge[0], iedge[1]) for iedge in iedges)

        if node_data:
            G.add_nodes_from((n, self._node[n].copy()) for n in G.nodes)

        return G

    def __remove_iedge(self, iedge):
        """Remove the impulse edge from the impulse graph.

        Quiet if the specified edge is not present.

        Parameters
        ----------
        iedge : Edge Tuple (u, v, t)
            Edge to be removed.
        """

        try:
            self.tree[iedge[2]].remove((iedge[0], iedge[1]))
            del self._pred[iedge[0]][iedge[1]][iedge]
            del self._succ[iedge[1]][iedge[0]][iedge]
        except KeyError:
            return

    def __validate_interval(self, begin=None, end=None):
        """Returns validated begin and end.
        Raises an exception if begin is larger than end.

        Parameters
        ----------
        begin : int or float, optional
        end : int or float, optional
        """

        if (begin is not None and end is not None) and begin > end:
            raise NetworkXError(
                "ImpulseDiGraph: interval end must be bigger than or equal to begin: "
                "begin: {}, end: {}.".format(begin, end))

        return begin, end

    def __search_tree(self, begin=None, end=None, inclusive=(True, True)):
        """if begin and end are equal performs a point search on the tree,
        otherwise an interval search is performed.

       Parameters
       ----------
       begin: int or float, optional  (default= beginning of the entire impulse graph)
       end: int or float, optional  (default= end of the entire impulse graph)
            Must be bigger than or equal to begin.
       inclusive: 2-tuple boolean that determines inclusivity of begin and end
       """
        begin, end = self.__validate_interval(begin, end)

        if begin is not None and begin == end and begin in self.tree:
            for edge in self.tree[begin]:
                yield (*edge, begin)

        for t in self.tree.irange(begin, end, inclusive=inclusive):
            for edge in self.tree[t]:
                yield (*edge, t)

    def __in_interval(self, t, begin, end, inclusive=(True, True)):
        """
        Parameters
        ----------
        t: int or float, timestamp
        begin: int or float
            Beginning time of Interval.
        end: int or float
            Ending time of Interval.
            Must be bigger than or equal to begin.
        inclusive: 2-tuple boolean that determines inclusivity of begin and end

        Returns
        -------
        Returns True if t is in the interval (begin,end). Otherwise False.
        """
        if begin is None:
            begin = float('-inf')
        if end is None:
            end = float('inf')

        if inclusive == (True, True):
            return begin <= t <= end
        if inclusive == (True, False):
            return begin <= t < end
        if inclusive == (False, True):
            return begin < t <= end
        if inclusive == (False, False):
            return begin < t < end

    @staticmethod
    def load_from_txt(path,
                      delimiter=" ",
                      nodetype=int,
                      timestamptype=float,
                      order=('u', 'v', 't'),
                      comments="#"):
        """Read impulse graph in from path.
           Timestamps must be integers or floats.
           Nodes can be any hashable objects.
           Edge Attributes can be assigned with in the following format: Key=Value

        Parameters
        ----------
        path : string or file
           Filename to read.

        nodetype : Python type, optional (default= int)
           Convert nodes to this type.

        timestamptype : Python type, optional (default= float)
           Convert timestamp to this type.
           This must be an orderable type, ideally int or float. Other orderable types have not been fully tested.

        order : Python 3-tuple, optional (default= ('u', 'v', 't'))
           This must be a 3-tuple containing strings 'u', 'v', and 't'. 'u' specifies the starting node, 'v' the ending node, and 't' the timestamp.

        comments : string, optional
           Marker for comment lines

        delimiter : string, optional
           Separator for node labels.  The default is whitespace. Cannot be =.

        Returns
        -------
        G: ImpulseGraph
            The graph corresponding to the lines in edge list.

        Examples
        --------
        >>> G=dnx.ImpulseGraph.load_from_txt("my_dygraph.txt")

        The optional nodetype is a function to convert node strings to nodetype.

        For example

        >>> G=dnx.ImpulseGraph.load_from_txt("my_dygraph.txt", nodetype=int)

        will attempt to convert all nodes to integer type.

        Since nodes must be hashable, the function nodetype must return hashable
        types (e.g. int, float, str, frozenset - or tuples of those, etc.)
        """

        G = ImpulseDiGraph()

        if delimiter == '=':
            raise ValueError("Delimiter cannot be =.")

        if len(order) != 3 or 'u' not in order or 'v' not in order or 't' not in order:
            raise ValueError(
                "Order must be a 3-tuple containing strings 'u', 'v', and 't'.")

        with open(path, 'r') as file:
            for line in file:
                p = line.find(comments)
                if p >= 0:
                    line = line[:p]
                if not len(line):
                    continue

                line = re.split(delimiter + '+', line.strip())

                u = line[order.index('u')]
                v = line[order.index('v')]
                t = line[order.index('t')]

                edgedata = {}
                for data in line[3:]:
                    key, value = data.split('=')

                    try:
                        value = float(value)
                    except:
                        pass
                    edgedata[key] = value

                if nodetype is not int:
                    try:
                        u = nodetype(u)
                        v = nodetype(v)
                    except:
                        raise TypeError(
                            "Failed to convert node to {0}".format(nodetype))
                else:
                    try:
                        u = int(u)
                        v = int(v)
                    except:
                        pass

                try:
                    t = timestamptype(t)
                except:
                    raise TypeError(
                        "Failed to convert interval time to {}".format(
                            timestamptype))

                G.add_edge(u, v, t, **edgedata)

        return G
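The class above buckets edges in `self.tree`, a SortedDict keyed by timestamp, so interval queries such as `edges(begin=..., end=...)` reduce to an `irange()` scan over the keys via `__search_tree`. A stripped-down sketch of that core pattern; the `edges_between` helper is illustrative and not part of the class's API.

from sortedcontainers import SortedDict

tree = SortedDict()
for u, v, t in [(1, 2, 10), (2, 4, 11), (2, 4, 15), (6, 4, 19)]:
    tree.setdefault(t, set()).add((u, v))   # timestamp -> set of (u, v) pairs

def edges_between(tree, begin, end, inclusive=(True, True)):
    # irange() walks only the timestamps inside [begin, end].
    return [(u, v, t)
            for t in tree.irange(begin, end, inclusive=inclusive)
            for (u, v) in sorted(tree[t])]

print(edges_between(tree, 11, 15))  # [(2, 4, 11), (2, 4, 15)]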
Example #21
0
writer = skvideo.io.FFmpegWriter(fullPathtoOutputVideo, outputdict={'-vcodec': 'libx264', '-b': '750100000'})

while not allFramesAreDone:

    # wait for the output channel to be full
    ConfigurationForVideoSegmentation.barrier.wait()
    # Crossing this barrier means that all workers have put their results into the output queue

    # Extract all annotated frames in increasing order. Use SortedDict for the purpose.
    s = SortedDict()

    while not ConfigurationForVideoSegmentation.outputChannel.empty():
        s.update(ConfigurationForVideoSegmentation.outputChannel.get())

    # The sorted container sorts by keys, and keys are frame numbers, so iterating
    # over the keys yields the frames in increasing order.
    for key in list(s.keys()):

        # Count the number of sentinel objects encountered
        if key == -1:

            nSentinelObjectsEncountered += 1

            # Not all workers may have seen the sentinel object yet.

            if nSentinelObjectsEncountered == ConfigurationForVideoSegmentation.nProcesses - 1:
                allFramesAreDone = True
                break
        else:
            image = s[key]
            writer.writeFrame(image)
            del s[key]
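The essence of the loop above is a reordering buffer: workers deliver frames out of order, a SortedDict re-sorts them by frame number, and frames are written and discarded in increasing key order. A reduced sketch without the multiprocessing and video plumbing; all names here are illustrative.

from sortedcontainers import SortedDict

arrivals = [(3, 'frame3'), (1, 'frame1'), (2, 'frame2')]  # out-of-order results

pending = SortedDict()
pending.update(arrivals)

for frame_number in list(pending.keys()):   # keys come back sorted: 1, 2, 3
    frame = pending[frame_number]
    print('writing', frame_number, frame)   # stand-in for writer.writeFrame(frame)
    del pending[frame_number]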
Example #22
0
start2 = timer()
print('join prep done, time12 = ', start2 - start1)

join_resolution = m2_resolution
join = SortedDict()

ktr = 0
print('add goes b5 to join')
for k in range(len(g_idx_valid[0])):
    goes_id = g_idx_valid[0][k]  # renamed from `id` to avoid shadowing the builtin
    jk = gd.spatial_clear_to_resolution(
        gd.spatial_coerce_resolution(goes_b5_indices[goes_id], join_resolution))
    # print('id,jk: ', goes_id, hex16(jk))
    # print('keys:  ', [hex16(i) for i in join.keys()])
    if jk not in join:
        join[jk] = join_value()
    join[jk].add('goes_b5', goes_id)
    ktr += 1
    # if ktr > 10:
    #     break
    #     # exit();

start2a = timer()
print('add goes b5 to join, time = ', start2a - start2)

print('add MERRA-2 to join')
for k in range(len(m2_indices)):
    jk = gd.spatial_clear_to_resolution(m2_indices[k])
    if jk not in join:
        join[jk] = join_value()
Example #23
0
class KeyedRegion(object):
    """
    KeyedRegion keeps a mapping between stack offsets and all objects covering that offset. It assumes no variable in
    this region overlap with another variable in this region.

    Registers and function frames can all be viewed as a keyed region.
    """
    def __init__(self, tree=None):
        self._storage = SortedDict() if tree is None else tree

    def _get_container(self, offset):
        try:
            base_offset = next(self._storage.irange(maximum=offset, reverse=True))
        except StopIteration:
            return offset, None
        else:
            container = self._storage[base_offset]
            if container.includes(offset):
                return base_offset, container
            return offset, None

    def __contains__(self, offset):
        """
        Test if there is at least one variable covering the given offset.

        :param offset:
        :return:
        """

        return self._get_container(offset)[1] is not None

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        return iter(self._storage.values())

    def __eq__(self, other):
        if set(self._storage.keys()) != set(other._storage.keys()):
            return False

        for k, v in self._storage.items():
            if v != other._storage[k]:
                return False

        return True

    def copy(self):
        if not self._storage:
            return KeyedRegion()

        kr = KeyedRegion()
        for key, ro in self._storage.items():
            kr._storage[key] = ro.copy()
        return kr

    def merge(self, other, make_phi_func=None):
        """
        Merge another KeyedRegion into this KeyedRegion.

        :param KeyedRegion other: The other instance to merge with.
        :return: None
        """

        # TODO: is the current solution not optimal enough?
        for _, item in other._storage.items():  # type: RegionObject
            for loc_and_var in item.stored_objects:
                self.__store(loc_and_var, overwrite=False, make_phi_func=make_phi_func)

        return self

    def dbg_repr(self):
        """
        Get a debugging representation of this keyed region.
        :return: A string of debugging output.
        """
        keys = self._storage.keys()
        offset_to_vars = { }

        for key in sorted(keys):
            ro = self._storage[key]
            variables = [ obj.obj for obj in ro.stored_objects ]
            offset_to_vars[key] = variables

        s = [ ]
        for offset, variables in offset_to_vars.items():
            s.append("Offset %#x: %s" % (offset, variables))
        return "\n".join(s)

    def add_variable(self, start, variable):
        """
        Add a variable to this region at the given offset.

        :param int start:
        :param SimVariable variable:
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.add_object(start, variable, size)

    def add_object(self, start, obj, object_size):
        """
        Add/Store an object to this region at the given offset.

        :param start:
        :param obj:
        :param int object_size: Size of the object
        :return:
        """

        self._store(start, obj, object_size, overwrite=False)

    def set_variable(self, start, variable):
        """
        Add a variable to this region at the given offset, and remove all other variables that are fully covered by
        this variable.

        :param int start:
        :param SimVariable variable:
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.set_object(start, variable, size)

    def set_object(self, start, obj, object_size):
        """
        Add an object to this region at the given offset, and remove all other objects that are fully covered by this
        object.

        :param start:
        :param obj:
        :param object_size:
        :return:
        """

        self._store(start, obj, object_size, overwrite=True)

    def get_base_addr(self, addr):
        """
        Get the base offset (the key we are using to index objects covering the given offset) of a specific offset.

        :param int addr:
        :return:
        :rtype:  int or None
        """

        base_addr, container = self._get_container(addr)
        if container is None:
            return None
        else:
            return base_addr

    def get_variables_by_offset(self, start):
        """
        Find variables covering the given region offset.

        :param int start:
        :return: A list of stack variables.
        :rtype:  set
        """

        _, container = self._get_container(start)
        if container is None:
            return []
        else:
            return container.internal_objects

    def get_objects_by_offset(self, start):
        """
        Find objects covering the given region offset.

        :param start:
        :return:
        """

        _, container = self._get_container(start)
        if container is None:
            return set()
        else:
            return container.internal_objects

    #
    # Private methods
    #

    def _store(self, start, obj, size, overwrite=False):
        """
        Store a variable into the storage.

        :param int start: The beginning address of the variable.
        :param obj: The object to store.
        :param int size: Size of the object to store.
        :param bool overwrite: Whether existing objects should be overwritten or not.
        :return: None
        """

        stored_object = StoredObject(start, obj, size)
        self.__store(stored_object, overwrite=overwrite)

    def __store(self, stored_object, overwrite=False, make_phi_func=None):
        """
        Store a variable into the storage.

        :param StoredObject stored_object: The descriptor describing start address and the variable.
        :param bool overwrite: Whether existing objects should be overwritten or not.
        :return: None
        """

        start = stored_object.start
        object_size = stored_object.size
        end = start + object_size

        # region items in the middle
        overlapping_items = list(self._storage.irange(start, end-1))

        # is there a region item that begins before the start and overlaps with this variable?
        floor_key, floor_item = self._get_container(start)
        if floor_item is not None and floor_key not in overlapping_items:
            # insert it into the beginning
            overlapping_items.insert(0, floor_key)

        # scan through the entire list of region items, split existing regions and insert new regions as needed
        to_update = {start: RegionObject(start, object_size, {stored_object})}
        last_end = start

        for floor_key in overlapping_items:
            item = self._storage[floor_key]
            if item.start < start:
                # we need to break this item into two
                a, b = item.split(start)
                if overwrite:
                    b.set_object(stored_object)
                else:
                    self._add_object_or_make_phi(b, stored_object, make_phi_func=make_phi_func)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.end
            elif item.start > last_end:
                # there is a gap between the last item and the current item
                # fill in the gap
                new_item = RegionObject(last_end, item.start - last_end, {stored_object})
                to_update[new_item.start] = new_item
                last_end = new_item.end
            elif item.end > end:
                # we need to split this item into two
                a, b = item.split(end)
                if overwrite:
                    a.set_object(stored_object)
                else:
                    self._add_object_or_make_phi(a, stored_object, make_phi_func=make_phi_func)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.end
            else:
                if overwrite:
                    item.set_object(stored_object)
                else:
                    self._add_object_or_make_phi(item, stored_object, make_phi_func=make_phi_func)
                to_update[item.start] = item

        self._storage.update(to_update)

    def _is_overlapping(self, start, variable):

        if variable.size is not None:
            # make sure this variable does not overlap with any other variable
            end = start + variable.size
            try:
                prev_offset = next(self._storage.irange(maximum=end-1, reverse=True))
            except StopIteration:
                prev_offset = None

            if prev_offset is not None:
                if start <= prev_offset < end:
                    return True
                prev_item = self._storage[prev_offset][0]
                prev_item_size = prev_item.size if prev_item.size is not None else 1
                if start < prev_offset + prev_item_size < end:
                    return True
        else:
            try:
                prev_offset = next(self._storage.irange(maximum=start, reverse=True))
            except StopIteration:
                prev_offset = None

            if prev_offset is not None:
                prev_item = self._storage[prev_offset][0]
                prev_item_size = prev_item.size if prev_item.size is not None else 1
                if prev_offset <= start < prev_offset + prev_item_size:
                    return True

        return False

    def _add_object_or_make_phi(self, item, stored_object, make_phi_func=None):  #pylint:disable=no-self-use
        if not make_phi_func or len({stored_object.obj} | item.internal_objects) == 1:
            item.add_object(stored_object)
        else:
            # make a phi node
            item.set_object(StoredObject(stored_object.start,
                                         make_phi_func(stored_object.obj, *item.internal_objects),
                                         stored_object.size,
                                         )
                            )
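As a point of reference, here is a minimal standalone sketch of the floor-lookup idea used by _get_container above: region start offsets are the keys of a SortedDict, and the region covering an address is the greatest key not larger than the address, found with irange(..., reverse=True). The regions dict and the get_base_addr helper below are illustrative placeholders, not part of the class above.

from sortedcontainers import SortedDict

regions = SortedDict()            # start offset -> (size, set of stored objects)
regions[0] = (4, {"a"})
regions[8] = (8, {"b"})

def get_base_addr(addr):
    # floor lookup: the greatest start offset that is <= addr
    try:
        base = next(regions.irange(maximum=addr, reverse=True))
    except StopIteration:
        return None
    size, _ = regions[base]
    return base if base <= addr < base + size else None

print(get_base_addr(10))   # -> 8
print(get_base_addr(5))    # -> None, the address falls in the gap between regions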
Example #24
0
class Book(object):
    '''| TODO
    | Keep track of the active orders. It is constructed by using unordered_map, map, and vector data-structures.
    | Unordered map is used to keep pointers to all active orders. In this implementation, it is used to check whether an order already
    | exists in the book. Sorted maps are used to represent the bid and ask depths of the book using the price as a key. For efficiency,
    | the price is represented as (scaled) uint64_t. The insert operation inserts the element at the correct place implementing the
    | price priority in the book. Each element of the maps is a price level (see Level.hpp). Note that in this Python version the bid map is
    | keyed with neg, so the best bid is the first element (price level) of the bid map; the best ask is the first element of the ask map.
    |________'''
    def __init__(self):
        self.bid = SortedDict(
            neg)  # Key is price as int, value is Level (in descending order)
        self.ask = SortedDict()  # Key is price as int, value is Level

        # Uniqueness of the keys is guaranteed only for active orders, i.e., if an order is removed, another order with the same key can be added
        self.activeOrders = {}  # Unordered map of active orders; key is order Id, value is Order. Used for quick lookup of orders.
        # Otherwise, we would need to iterate over the levels (bid and ask) and check the orders for the orderId in question

    def isBidEmpty(self):
        return len(self.bid) == 0

    def isAskEmpty(self):
        return len(self.ask) == 0

    def isEmpty(self):
        return len(self.activeOrders) == 0

    def isPresent(self, anOrderId):
        return anOrderId in self.activeOrders

    def addOrder(self,
                 isBuy=None,
                 orderId=None,
                 price=None,
                 qty=None,
                 peakSize=None,
                 order=None):
        '''| TODO 
        | Creates and adds an order to the map of orders. In addition, pointer to the order is added to the proper map (bid/ask) and
        | vector (price level). The maps are ordered, therefore, inserting elements with price as keys, automatically builds a correct
        | depth of the book.
        | Note: the bid map is keyed with neg, so the best bid is its first element (price level); the best ask is the first element (price level) of the ask map.
        |________'''

        # Already checked that an order with the same Id is not present in the book
        myOrder = Order(isBuy, orderId, price, qty,
                        peakSize) if order is None else order

        self.activeOrders[myOrder.orderId] = myOrder

        # TODO: Where do we deal with int*100 price as keys?
        key_price = int(myOrder.price * 100)

        level = self.bid if myOrder.isBuy else self.ask

        if key_price not in level:
            level[key_price] = Level()

        level[key_price].addOrder(myOrder)

    def removeOrder(self, orderId):
        '''| TODO 
        | Removes an active order from the book if present (return false if order not found).
        | In case of icebergs, removes both the visible and hidden parts.
        |________'''
        if orderId in self.activeOrders:
            isBuy = self.activeOrders[orderId].isBuy
            key_price = int(self.activeOrders[orderId].price * 100)

            level = self.bid if isBuy else self.ask
            level[key_price].remove(orderId)

            del self.activeOrders[orderId]
            return True

        return False

    def removeActiveOrder(self, orderId):
        '''| TODO 
        |________'''
        if orderId in self.activeOrders:
            del self.activeOrders[orderId]
            return True

        return False

    def removeEmptyLevels(self):
        '''| TODO 
        | If an incoming order executes and matches with all active orders of the best level
        | including visible and invisible part of the orders, the level is considered empty
        | and the matching continues with the next price level. After the execution, before
        | adding an order and processing a new incoming order, this function is used to remove
        | all empty levels. The book state is updated with new best (bid/ask) levels.
        |________'''
        # iterate over a copy of the keys since entries are deleted below
        for price in list(self.bid.keys()):
            if self.bid[price].isEmpty():
                del self.bid[price]

        for price in list(self.ask.keys()):
            if self.ask[price].isEmpty():
                del self.ask[price]

    def clear(self):
        self.activeOrders.clear()
        self.bid.clear()
        self.ask.clear()

    def show(self):
        '''| TODO 
        | Called as a result of command 's'
        | Since the best price is listed first, and the maps used to store the levels are ordered,
        | this function outputs the bid levels by traversing the bid map in reverse (highest price first)
        |________'''

        if self.isEmpty():
            print "Book --- EMPTY ---"

        else:
            if len(self.bid) == 0:
                print "Bid depth --- EMPTY ---"
            else:
                print "Bid depth (highest priority at top):"
                print "Price     ", "Order Id  ", "Quantity  ", "Iceberg"

                # Highest price first
                for _, level in self.bid.iteritems():
                    level.show()
                print

            if len(self.ask) == 0:
                print "Ask depth --- EMPTY ---"
            else:
                print "Ask depth (highest priority at top):"
                print "Price     ", "Order Id  ", "Quantity  ", "Iceberg"

                # Lowest price first
                for _, level in self.ask.iteritems():
                    level.show()
        print
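A hedged usage sketch of the two-map layout above, with Level and Order objects stubbed out as plain strings: because the bid map is built with neg as its key function, its keys iterate from highest to lowest price, so the best bid is its first item, while the best ask is the first item of the ordinary ascending ask map.

from operator import neg
from sortedcontainers import SortedDict

bid = SortedDict(neg)                  # price keys iterate from highest to lowest
ask = SortedDict()                     # price keys iterate from lowest to highest
bid[int(101.25 * 100)] = "bid level"
bid[int(100.50 * 100)] = "bid level"
ask[int(101.75 * 100)] = "ask level"
ask[int(102.00 * 100)] = "ask level"

best_bid, _ = bid.peekitem(0)          # highest bid price
best_ask, _ = ask.peekitem(0)          # lowest ask price
print(best_bid / 100.0, best_ask / 100.0)   # 101.25 101.75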
Example #25
0
def open_file(path):

  book = xlrd.open_workbook(path)
  sheet = book.sheet_by_index(0)

  major_name = []
  numStudByMajor1y = []
  numStudByMajor2y = []
  for row_index in xrange(1, sheet.nrows):
    row = sheet.row_values(row_index)
    if row[4] not in major_name:
      major_name.append(row[4])
    if len(major_name) == 18:
      break

  flow = dict([(major,{}) for major in major_name])
  numStudByMajor1y = dict([(major,{}) for major in major_name])
  numStudByMajor2y = dict([(major,{}) for major in major_name])

  for key in flow:
    flow[key] = dict([(major,{}) for major in major_name])
    numStudByMajor1y[key]["Total"] = 0;
    numStudByMajor1y[key]["Male"] = 0;
    numStudByMajor1y[key]["Female"] = 0;
    numStudByMajor1y[key]["Good"] = 0;
    numStudByMajor1y[key]["Moderate"] = 0;
    numStudByMajor1y[key]["Poor"] = 0;

    numStudByMajor2y[key]["Total"] = 0;
    numStudByMajor2y[key]["Male"] = 0;
    numStudByMajor2y[key]["Female"] = 0;
    numStudByMajor2y[key]["Good"] = 0;
    numStudByMajor2y[key]["Moderate"] = 0;
    numStudByMajor2y[key]["Poor"] = 0;

    for key1 in flow[key]:
      flow[key][key1]["_1Total"] = 0
      flow[key][key1]["_1Male"] = 0
      flow[key][key1]["_1Female"] = 0
      flow[key][key1]["_1Good"] = 0
      flow[key][key1]["_1Moderate"] = 0
      flow[key][key1]["_1Poor"] = 0

      flow[key][key1]["_2Total"] = 0
      flow[key][key1]["_2Male"] = 0
      flow[key][key1]["_2Female"] = 0
      flow[key][key1]["_2Good"] = 0
      flow[key][key1]["_2Moderate"] = 0
      flow[key][key1]["_2Poor"] = 0

  for key in flow:
    print flow[key]

  for row_index in xrange(1, sheet.nrows):
    row = sheet.row_values(row_index)
    numStudByMajor1y[row[2]]["Total"] += row[12]
    numStudByMajor1y[row[2]][row[0]] += row[12]
    numStudByMajor1y[row[2]][row[6]] += row[12]

    numStudByMajor2y[row[3]]["Total"] += row[12]
    numStudByMajor2y[row[3]][row[0]] += row[12]
    numStudByMajor2y[row[3]][row[6]] += row[12]

  for row_index in xrange(1, sheet.nrows):
    row = sheet.row_values(row_index)
    flow[row[2]][row[3]]["_1Total"] += row[12] / numStudByMajor1y[row[2]]["Total"]
    flow[row[2]][row[3]]["_1" + row[0]] += row[12] / numStudByMajor1y[row[2]][row[0]]
    flow[row[2]][row[3]]["_1" + row[6]] += row[12] / numStudByMajor1y[row[2]][row[6]]

    flow[row[3]][row[4]]["_2Total"] += row[12] / numStudByMajor2y[row[3]]["Total"]
    flow[row[3]][row[4]]["_2" + row[0]] += row[12] / numStudByMajor2y[row[3]][row[0]]
    flow[row[3]][row[4]]["_2" + row[6]] += row[12] / numStudByMajor2y[row[3]][row[6]]

  flow = SortedDict(flow)
  for key in flow:
    flow[key] = SortedDict(flow[key])

  with open('flow.csv', 'wb') as testfile:
      csv_writer = csv.writer(testfile)
      title = []
      title.extend(flow.keys())
      csv_writer.writerow(title)
      for key in flow:
        row = []
        row.extend(flow[key].values())
        csv_writer.writerow(row)
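A small sketch of the final step above, with made-up major names standing in for the real data: wrapping the plain nested dicts in SortedDict sorts the keys, so the CSV rows and columns come out in a stable alphabetical order.

from sortedcontainers import SortedDict

flow = {"Physics": {"b": 0.2, "a": 0.8}, "Biology": {"d": 0.4, "c": 0.6}}
flow = SortedDict(flow)                      # outer keys now iterate alphabetically
for major in flow:
    flow[major] = SortedDict(flow[major])    # inner keys too
print(list(flow.keys()))                     # ['Biology', 'Physics']
print(list(flow["Physics"].items()))         # [('a', 0.8), ('b', 0.2)]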
Example #26
0
class GazeEvents:
    Timestamp = "Timestamp"
    LoggedEvents = "Logged Events"
    GazeType = "Gaze Type"
    EventIndex = "Event Index"
    EventDuration = "Event Duration"
    Fixation_X = "Fixation X"
    Fixation_Y = "Fixation Y"
    AOI_Mapped_Fixation_X = "AOI_Mapped_Fixation_X"
    AOI_Mapped_Fixation_Y = "AOI_Mapped_Fixation_Y"
    AOI = "AOI"
    AOI_Score = "AOI Score"
    Saccade_Start_X = "Saccade Start X"
    Saccade_Start_Y = "Saccade Start Y"
    Saccade_End_X = "Saccade End X"
    Saccade_End_Y = "Saccade End Y"
    InputFixationFilterX = GazeData.GazePixelX
    InputFixationFilterY = GazeData.GazePixelY

    def __init__(self):
        self.__events__ = SortedDict({})
        self.__init_datatypes__()
        self.__fixation_index__ = 0
        self.__saccade_index__ = 0
        self.__ts_processed__ = []
        self.__fixations__ = SortedDict({})
        self.__saccades__ = {}

    def __getitem__(self, key):
        return self.__events__[key]

    def __getFilteredGazeData__(self, gazedata_df, ts_filter=None):
        if ts_filter is None:
            ts_filter = TimestampFilter(list(gazedata_df[GazeData.Timestamp]))
        df, ts_list = ts_filter.getFilteredData(gazedata_df,
                                                self.__ts_processed__)
        self.__ts_processed__.extend(ts_list)
        return df, ts_list

    def __init_datatypes__(self):
        self.__events__[GazeEvents.Timestamp] = GazeItem(
            GazeEvents.Timestamp, np.dtype('float'))
        self.__events__[GazeEvents.LoggedEvents] = GazeItem(
            GazeEvents.LoggedEvents, np.dtype(object))
        self.__events__[GazeEvents.GazeType] = GazeItem(
            GazeEvents.GazeType, np.dtype(object))
        self.__events__[GazeEvents.Fixation_X] = GazeItem(
            GazeEvents.Fixation_X, np.dtype('u4'))
        self.__events__[GazeEvents.Fixation_Y] = GazeItem(
            GazeEvents.Fixation_Y, np.dtype('u4'))
        self.__events__[GazeEvents.EventIndex] = GazeItem(
            GazeEvents.EventIndex, np.dtype('u4'))
        self.__events__[GazeEvents.EventDuration] = GazeItem(
            GazeEvents.EventDuration, np.dtype('u4'))
        self.__events__[GazeEvents.AOI_Mapped_Fixation_X] = GazeItem(
            GazeEvents.AOI_Mapped_Fixation_X, np.dtype('u4'))
        self.__events__[GazeEvents.AOI_Mapped_Fixation_Y] = GazeItem(
            GazeEvents.AOI_Mapped_Fixation_Y, np.dtype('u4'))
        self.__events__[GazeEvents.AOI] = GazeItem(GazeEvents.AOI,
                                                   np.dtype(object))
        self.__events__[GazeEvents.AOI_Score] = GazeItem(
            GazeEvents.AOI_Score, np.dtype('f2'))
        self.__events__[GazeEvents.Saccade_Start_X] = GazeItem(
            GazeEvents.Saccade_Start_X, np.dtype('f2'))
        self.__events__[GazeEvents.Saccade_Start_Y] = GazeItem(
            GazeEvents.Saccade_Start_Y, np.dtype('f2'))
        self.__events__[GazeEvents.Saccade_End_X] = GazeItem(
            GazeEvents.Saccade_End_X, np.dtype('f2'))
        self.__events__[GazeEvents.Saccade_End_Y] = GazeItem(
            GazeEvents.Saccade_End_Y, np.dtype('f2'))

    def addFixation(self, ts, index, duration, fixation_x, fixation_y):
        self.__events__[GazeEvents.Timestamp][ts] = ts
        self.__events__[GazeEvents.GazeType][ts] = "Fixation"
        self.__events__[GazeEvents.EventIndex][ts] = index
        self.__events__[GazeEvents.EventDuration][ts] = duration
        self.__events__[GazeEvents.Fixation_X][ts] = fixation_x
        self.__events__[GazeEvents.Fixation_Y][ts] = fixation_y
        self.__fixations__[ts] = (fixation_x, fixation_y, duration)

    def addLoggedEvent(self, ts, logged_event):
        self.__events__[GazeEvents.Timestamp][ts] = ts
        self.__events__[GazeEvents.LoggedEvents][ts] = logged_event

    def addSaccade(self, ts, index, duration, saccade_start_x, saccade_start_y,
                   saccade_end_x, saccade_end_y):
        self.__events__[GazeEvents.Timestamp][ts] = ts
        self.__events__[GazeEvents.GazeType][ts] = "Saccade"
        self.__events__[GazeEvents.EventIndex][ts] = index
        self.__events__[GazeEvents.EventDuration][ts] = duration
        self.__events__[GazeEvents.Saccade_Start_X][ts] = saccade_start_x
        self.__events__[GazeEvents.Saccade_Start_Y][ts] = saccade_start_y
        self.__events__[GazeEvents.Saccade_End_X][ts] = saccade_end_x
        self.__events__[GazeEvents.Saccade_End_Y][ts] = saccade_end_y
        # keyed by timestamp, mirroring __fixations__
        self.__saccades__[ts] = (saccade_start_x, saccade_start_y,
                                 saccade_end_x, saccade_end_y, duration)

    def exportCSV(self, filepath, filename, ts_filter=None):
        fixations_df = self.toDataFrame(ts_filter).dropna(
            subset=[GazeEvents.Fixation_X, GazeEvents.Fixation_Y])
        exp = FixationsCSV(filepath, filename, fixations_df)
        exp.toCSV()

    def exportDF(self, filepath, filename, ts_filter=None):
        logging.info('Exporting gaze events in %s' % filename)
        path = os.path.join(filepath, filename)
        self.toDataFrame(ts_filter).dropna(
            subset=[GazeEvents.Fixation_X, GazeEvents.Fixation_Y]).to_pickle(
                path)

    def filterFixations(self, fixation_filter, gazedata_df, ts_filter=None):
        df, ts_list = self.__getFilteredGazeData__(gazedata_df, ts_filter)
        if not df.empty:
            x = pd.Series(df[GazeEvents.InputFixationFilterX])
            y = pd.Series(df[GazeEvents.InputFixationFilterY])
            fixation_filter.setData(x, y)
            fixation_filter.filter(self)

    def getFixations(self, ts_filter=None):
        df = self.toDataFrame(ts_filter)
        return df.loc[df[GazeEvents.GazeType] == 'Fixation']

    def getClosestFixation(self, ts):
        ts_list = list(self.__fixations__.keys())
        ts_index = bisect.bisect_left(ts_list, ts)
        if ts_index == 0:
            ts_closest = ts_list[ts_index]
        else:
            ts_closest = ts_list[ts_index - 1]
        return self.__fixations__[ts_closest]

    def getSaccades(self, ts_filter=None):
        df = self.toDataFrame(ts_filter)
        return df.loc[df[GazeEvents.GazeType] == 'Saccade']

    def getFixationsAsNumpy(self, ts_filter):
        fixations_df = self.toDataFrame(ts_filter).dropna(
            subset=[GazeEvents.Fixation_X, GazeEvents.Fixation_Y])
        x = fixations_df[GazeEvents.Fixation_X].values
        y = fixations_df[GazeEvents.Fixation_Y].values
        ts_list = fixations_df[GazeEvents.Timestamp].values.tolist()
        return (ts_list, x, y)

    def getTimestamps(self):
        return list(self.__events__[GazeEvents.Timestamp].values())

    def setAOI(self, ts, aoi_fixation_x, aoi_fixation_y, aoi_label, aoi_score):
        self.__events__[GazeEvents.AOI_Mapped_Fixation_X][ts] = aoi_fixation_x
        self.__events__[GazeEvents.AOI_Mapped_Fixation_Y][ts] = aoi_fixation_y
        self.__events__[GazeEvents.AOI][ts] = aoi_label
        self.__events__[GazeEvents.AOI_Score][ts] = aoi_score

    def toDataFrame(self, ts_filter=None):
        table = {}
        for label, data in self.__events__.items():
            table[label] = pd.Series(data.getData(), dtype=object)
        df = pd.DataFrame(table)
        if ts_filter is None:
            return df
        else:
            filtered_df, ts_list = ts_filter.getFilteredData(df)
            return filtered_df

    def to_pickle(self, filename, ts_filter=None):
        self.toDataFrame(ts_filter).to_pickle(filename)
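A hedged sketch of the timestamp lookup performed by getClosestFixation above, using SortedDict's own bisect_right instead of the bisect module; the fixations dict and its values are placeholders, and here a query that exactly matches a stored timestamp returns that fixation.

from sortedcontainers import SortedDict

fixations = SortedDict({1.0: "f1", 2.5: "f2", 4.0: "f3"})

def closest_at_or_before(ts):
    i = fixations.bisect_right(ts)     # number of stored timestamps <= ts
    keys = fixations.keys()
    return fixations[keys[i - 1]] if i else fixations[keys[0]]

print(closest_at_or_before(3.1))   # f2
print(closest_at_or_before(0.5))   # f1 (falls back to the earliest fixation)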
Example #27
0
class WordData(QObject):
    # Define the signal we emit when we have loaded new data
    WordsUpdated = pyqtSignal()

    def __init__(self, my_book):
        super().__init__(None)
        # Save reference to the book
        self.my_book = my_book
        # Save reference to the metamanager
        self.metamgr = my_book.get_meta_manager()
        # Save reference to the edited document
        self.document = my_book.get_edit_model()
        # Save reference to a speller, which will be the default
        # at this point.
        self.speller = my_book.get_speller()
        # The vocabulary list as a sorted dict.
        self.vocab = SortedDict()
        # Key and Values views on the vocab list for indexing by table row.
        self.vocab_kview = self.vocab.keys()
        self.vocab_vview = self.vocab.values()
        # The count of available words based on the latest sort
        self.active_word_count = 0
        # The good- and bad-words sets and the scannos set.
        self.good_words = set()
        self.bad_words = set()
        self.scannos = set()
        # A dict of words that use an alt-dict tag. The key is a word and the
        # value is the alt-dict tag string.
        self.alt_tags = SortedDict()
        # Cached sort vectors, see get_sort_vector()
        self.sort_up_vectors = [None, None, None]
        self.sort_down_vectors = [None, None, None]
        self.sort_key_funcs = [None, None, None]
        # Register metadata readers and writers.
        self.metamgr.register(C.MD_GW, self.good_read, self.good_save)
        self.metamgr.register(C.MD_BW, self.bad_read, self.bad_save)
        self.metamgr.register(C.MD_SC, self.scanno_read, self.scanno_save)
        self.metamgr.register(C.MD_VL, self.word_read, self.word_save)
    # End of __init__


    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    # Methods used when saving metadata. The items in the good_words,
    # bad_words, and scanno sets are simply returned as a list of strings.
    #
    def good_save(self, section) :
        return [ token for token in self.good_words ]

    def bad_save(self, section) :
        return [ token for token in self.bad_words ]

    def scanno_save(self, section) :
        return [ token for token in self.scannos ]
    #
    # To save the vocabulary, write a list for each word:
    #   [ "token", "tag", count, [prop-code...] ]
    # where "token" is the word as a string, "tag" is its alt-dict tag
    # or a null string, count is an integer and [prop-code...] is the
    # integer values from the word's property set as a list. Note that
    # alt_tag needs to be a string because json doesn't handle None.
    #
    def word_save(self, section) :
        vlist = []
        for word in self.vocab:
            [count, prop_set] = self.vocab[word]
            #tag = "" if AD not in prop_set else self.alt_tags[word]
            tag = ""
            if AD in prop_set :
                if word in self.alt_tags :
                    tag = self.alt_tags[word]
                else : # should never occur, could be assertion error
                    worddata_logger.error( 'erroneous alt tag on ' + word )
            plist = list(prop_set)
            vlist.append( [ word, count, tag, plist ] )
        return vlist

    #
    # Methods used to load metadata. Called by the metadata manager with
    # a single Python object, presumably the object that was prepared by
    # the matching _save method above. Because the user might edit the metadata
    # file, do a little quality control.
    #

    def good_read(self, section, value, version):
        if isinstance(value, list) :
            for token in value :
                if isinstance(token, str) :
                    if token in self.bad_words :
                        worddata_logger.warn(
                            '"{}" is in both good and bad words - use in good ignored'.format(token)
                            )
                    else :
                        self.good_words.add(token)
                        if token in self.vocab : # vocab already loaded, it seems
                            props = self.vocab[token][1]
                            props.add(GW)
                            props &= prop_nox
                else :
                    worddata_logger.error(
                        '{} in GOODWORDS list ignored'.format(token)
                        )
            if len(self.good_words) :
                # We loaded some, the display might need to change
                self.WordsUpdated.emit()
        else :
            worddata_logger.error(
                'GOODWORDS metadata is not a list of strings, ignoring it'
                )

    def bad_read(self, section, value, version):
        if isinstance(value, list) :
            for token in value :
                if isinstance(token, str) :
                    if token in self.good_words :
                        worddata_logger.warn(
                            '"{}" is in both good and bad words - use in bad ignored'.format(token)
                            )
                    else :
                        self.bad_words.add(token)
                        if token in self.vocab : # vocab already loaded, it seems
                            props = self.vocab[token][1]
                            props.add(BW)
                            props.add(XX)
                else :
                    worddata_logger.error(
                        '{} in BADWORDS list ignored'.format(token)
                        )
            if len(self.bad_words) :
                # We loaded some, the display might need to change
                self.WordsUpdated.emit()
        else :
            worddata_logger.error(
                'BADWORDS metadata is not a list of strings, ignoring it'
                )

    def scanno_read(self, section, value, version):
        if isinstance(value, list) :
            for token in value :
                if isinstance(token, str) :
                    self.scannos.add(token)
                else :
                    worddata_logger.error(
                        '{} in SCANNOLIST ignored'.format(token)
                        )
        else :
            worddata_logger.error(
                'SCANNOLIST metadata is not a list of strings, ignoring it'
                )

    # Load the vocabulary section of a metadata file, allowing for
    # user-edited malformed items. Be very generous about user errors in a
    # modified meta file. The expected value for each word is as written by
    # word_save() above, ["token", count, tag, [props]] but allow a single
    # item ["token"] or just "token" so the user can put in a single word
    # with no count or properties. Convert null-string alt-tag to None.
    #
    # Before adding a word make sure to unicode-flatten it.
    #
    def word_read(self, section, value, version) :
        global PROP_ALL, prop_nox
        # get a new speller in case the Book read a different dict already
        self.speller = self.my_book.get_speller()
        # if value isn't a list, bail out now
        if not isinstance(value,list):
            worddata_logger.error(
                'WORDCENSUS metadata is not a list, ignoring it'
                )
            return
        # inspect each item of the list.
        for wlist in value:
            try :
                if isinstance(wlist,str) :
                    # expand "token" to ["token"]
                    wlist = [wlist]
                if not isinstance(wlist, list) : raise ValueError
                if len(wlist) != 4 :
                    if len(wlist) > 4 :raise ValueError
                    if len(wlist) == 1 : wlist.append(0) # add default count of 0
                    if len(wlist) == 2 : wlist.append('') # add default alt-tag
                    if len(wlist) == 3 : wlist.append([]) # add default props
                word = wlist[0]
                if not isinstance(word,str) : raise ValueError
                word = unicodedata.normalize('NFKC',word)
                count = int(wlist[1]) # exception if not numeric
                alt_tag = wlist[2]
                if not isinstance(alt_tag,str) : raise ValueError
                if alt_tag == '' : alt_tag = None
                prop_set = set(wlist[3]) # exception if not iterable
                if len( prop_set - PROP_ALL ) : raise ValueError #bogus props
            except :
                worddata_logger.error(
                    'WORDCENSUS item {} is invalid, ignoring it'.format(wlist)
                    )
                continue
            # checking done, store the word.
            if (0 == len(prop_set)) or (0 == count) :
                # word with no properties or count is a user addition, enter
                # it as if we found it in the file, including deducing the
                # properties, spell-check, hyphenation split.
                self._add_token(word, alt_tag)
                continue # that's that, on to next line
            # Assume we have a word saved by word_save(), but possibly the
            # good_words and bad_words have been edited and read-in first.
            # Note we are not checking for duplicates
            if word in self.bad_words :
                prop_set.add(BW)
                prop_set.add(XX)
            if word in self.good_words :
                prop_set.add(GW)
                prop_set &= prop_nox
            if alt_tag :
                prop_set.add(AD)
                self.alt_tags[word] = alt_tag
            self.vocab[word] = [count, prop_set]
        # end of "for wlist in value"
        # note the current word count
        self.active_word_count = len(self.vocab)
        # Tell wordview that the display might need to change
        self.WordsUpdated.emit()
    # end of word_read()

    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    # Methods used when opening a new file, one with no metadata.
    #
    # The Book will call these methods passing a text stream when it finds a
    # good-words file or bad-words file. Each of these is expected to have
    # one token per line. We don't presume to know in what order the files
    # are presented, but we DO assume that the vocabulary census has not yet
    # been taken. That requires the user clicking Refresh and that cannot
    # have happened while first opening the file.

    def good_file(self, stream) :
        while not stream.atEnd() :
            token = stream.readLine().strip()
            if token in self.bad_words :
                worddata_logger.warn(
                    '"{}" is in both good and bad words - use in good ignored'.format(token)
                    )
            else :
                self.good_words.add(token)

    def bad_file(self, stream) :
        while not stream.atEnd() :
            token = stream.readLine().strip()
            if token in self.good_words :
                worddata_logger.warn(
                    '"{}" is in both good and bad words - use in bad ignored'.format(token)
                    )
            else :
                self.bad_words.add(token)
    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    #
    # The user can choose a new scannos file any time while editing. So there
    # might be existing data, so we clear the set before reading.
    #
    def scanno_file(self, stream) :
        self.scannos = set() # clear any prior values
        while not stream.atEnd() :
            token = stream.readLine().strip()
            self.scannos.add(token)

    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    # The following is called by the Book when the user chooses a different
    # spelling dictionary. Store a new spellcheck object. Recheck the
    # spelling of all words except those with properties HY, GW, or BW.
    #
    # NOTE IF THIS IS A PERFORMANCE BURDEN, KILL IT AND REQUIRE REFRESH
    #
    def recheck_spelling(self, speller):
        global PROP_BGH, prop_nox
        self.speller = speller
        for i in range(len(self.vocab)) :
            (c, p) = self.vocab_vview[i]
            if not( PROP_BGH & p ) : # then p lacks BW, GW and HY
                p = p & prop_nox # and now it also lacks XX
                w = self.vocab_kview[i]
                t = self.alt_tags.get(w,None)
                if not self.speller.check(w,t):
                    p.add(XX)
                self.vocab_vview[i][1] = p

    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    # Method to perform a census. This is called from wordview when the
    # user clicks the Refresh button asking for a new scan over all words in
    # the book. Formerly this took a progress bar, but the actual operation
    # is so fast no progress need be shown.
    #
    def refresh(self):
        global RE_LANG_ATTR, RE_TOKEN

        count = 0
        end_count = self.document.blockCount()

        # get a reference to the dictionary to use
        self.speller = self.my_book.get_speller()
        # clear the alt-dict list.
        self.alt_tags = SortedDict()
        # clear the sort vectors
        self.sort_up_vectors = [None, None, None]
        self.sort_down_vectors = [None, None, None]
        self.sort_key_funcs = [None, None, None]
        # Zero out all counts and property sets that we have so far. We will
        # develop new properties when each word is first seen. Properties
        # such as HY will not have changed, but both AD and XX might have
        # changed while the word text remains the same.
        for j in range(len(self.vocab)) :
            self.vocab_vview[j][0] = 0
            self.vocab_vview[j][1] = set()

        # iterate over all lines extracting tokens and processing them.
        alt_dict = None
        alt_tag = None
        for line in self.document.all_lines():
            count += 1
            j = 0
            m = RE_TOKEN.search(line,0)
            while m : # while match is not None
                if m.group(6) : # start-tag; has it lang= ?
                    d = RE_LANG_ATTR.search(m.group(8))
                    if d :
                        alt_dict = d.group(1)
                        alt_tag = m.group(7)
                elif m.group(9) :
                    if m.group(10) == alt_tag :
                        # end tag of a lang= start tag
                        alt_dict = None
                        alt_tag = None
                else :
                    self._add_token(m.group(0),alt_dict)
                j = m.end()
                m = RE_TOKEN.search(line,j)
        # Look for zero counts and delete those items. It is forbidden to
        # alter the dict contents while iterating over values or keys views,
        # so make a list of the word tokens to be deleted, then use del.
        togo = []
        for j in range(len(self.vocab)) :
            if self.vocab_vview[j][0] == 0 :
                togo.append(self.vocab_kview[j])
        for key in togo:
            del self.vocab[key]
        # Update possibly modified word count
        self.active_word_count = len(self.vocab)

    # Internal method for adding a possibly-hyphenated token to the vocabulary,
    # incrementing its count. This is used during the census/refresh scan, and
    # can be called from word_read to process a user-added word.
    # Arguments:
    #    tok_str: a normalized word-like token; may be hyphenated a/o apostrophized
    #    dic_tag: an alternate dictionary tag or None
    #
    # If the token has no hyphens, this is just a cover on _count. When the
    # token is hyphenated, we enter each part of it alone, then add the
    # phrase with the union of the prop_sets of its parts, plus HY. Thus
    # "mother-in-law's" will be added as "mother", "in" and "law's", and as
    # itself with HY, LC, AP. "1989-1995" puts 1989 and 1995 in the list and
    # will have HY and ND. Yes, this means that a hyphenation could have all
    # of UC, MC and LC.
    #
    # If a part of a phrase fails spellcheck, it will have XX but we do not
    # propagate that to the phrase itself.
    #
    # If a part of the phrase has AD (because it was previously entered as
    # part of a lang= string) that also is not propagated to the phrase
    # itself. Since hyphenated phrases are never spell-checked, they should
    # never have AD.
    #
    # Note: en-dash \u2013 is not supported here, only the ascii hyphen.
    # Support for it could be added if required.
    #
    # Defensive programming: '-'.split('-') --> ['','']; '-9'.split('-') --> ['','9']

    def _add_token(self, tok_str, dic_tag ) :
        global prop_nox
        # Count the entire token regardless of hyphens
        self._count(tok_str, dic_tag) # this definitely puts it in the dict
        [count, prop_set] = self.vocab[tok_str]
        if (count == 1) and (HY in prop_set) :
            # We just added a hyphenated token: add its parts also.
            parts = tok_str.split('-')
            prop_set = {HY}
            for member in parts :
                if len(member) : # if not null split from leading -
                    self._count(member, dic_tag)
                    [x, part_props] = self.vocab[member]
                    prop_set |= part_props
            self.vocab[tok_str] = [count, prop_set  - {XX, AD} ]

    # Internal method to count a token, adding it to the list if necessary.
    # An /alt-tag must already be removed. The word must be already
    # normalized. Because of the way we tokenize, we know the token contains
    # only letter forms, numeric forms, and possibly hyphens and/or
    # apostrophes.
    #
    # If it is in the list, increment its count. Otherwise, compute its
    # properties, including spellcheck for non-hyphenated tokens, and
    # add it to the vocabulary with a count of 1. Returns nothing.

    def _count(self, word, dic_tag ) :
        [count, prop_set] = self.vocab.get( word, [0,set()] )
        if count : # it was in the list: a new word would have count=0
            self.vocab[word][0] += 1 # increment its count
            return # and done.
        # Word was not in the list (but is now): count is 0, prop_set is empty.
        # The following is only done once per unique word.
        self.my_book.metadata_modified(True, C.MD_MOD_FLAG)
        work = word[:] # copy the word, we may modify it next.
        if work.startswith("Point"):
            pass # debug
        # If word has apostrophes, note that and delete for following tests.
        if -1 < work.find("'") : # look for ascii apostrophe
            prop_set.add(AP)
            work = work.replace("'","")
        if -1 < work.find('\u02bc') : # look for MODIFIER LETTER APOSTROPHE
            prop_set.add(AP)
            work = work.replace('\u02bc','')
        # If word has hyphens, note that and remove them.
        if -1 < work.find('-') :
            prop_set.add(HY)
            work = work.replace('-','')
        # With the hyphens and apostrophes out, check letter case
        if ANY_DIGIT.search( work ) :
            # word has at least one numeric
            prop_set.add(ND)
        if not work.isnumeric() :
            # word is not all-numeric, determine case of letters
            if work.lower() == work :
                prop_set.add(LC) # most common case
            elif work.upper() != work :
                prop_set.add(MC) # next most common case
            else : # work.upper() == work
                prop_set.add(UC)
        if HY not in prop_set : # word is not hyphenated, so check its spelling.
            if word not in self.good_words :
                if word not in self.bad_words :
                    # Word in neither good- nor bad-words
                    if dic_tag : # uses an alt dictionary
                        self.alt_tags[word] = dic_tag
                        prop_set.add(AD)
                    if not self.speller.check(word, dic_tag) :
                        prop_set.add(XX)
                else : # in bad-words
                    prop_set.add(XX)
            # else in good-words
        # else hyphenated, spellcheck only its parts
        self.vocab[word] = [1, prop_set]

    # =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    #
    # The following methods are called from the Words panel.
    #
    #  Get the count of words in the vocabulary, as selected by the
    #  latest sort vector.
    #
    def word_count(self):
        return self.active_word_count
    #
    # Get the actual size of the vocabulary, for searching it all.
    def vocab_count(self):
        return len(self.vocab)
    #
    # Get the word at position n in the vocabulary, using the SortedDict
    # KeysView for fast indexed lookup. Guard against invalid indices.
    #
    def word_at(self, n):
        try:
            return self.vocab_kview[n]
        except Exception as whatever:
            worddata_logger.error('bad call to word_at({0})'.format(n))
            return ('?')
    #
    # Get the count and/or property-set of the word at position n in the
    # vocabulary, using the SortedDict ValuesView for fast indexed lookup.
    #
    def word_info_at(self, n):
        try:
            return self.vocab_vview[n]
        except Exception as whatever:
            worddata_logger.error('bad call to word_info_at({0})'.format(n))
            return [0, set()]
    def word_count_at(self, n):
        try:
            return self.vocab_vview[n][0]
        except Exception as whatever:
            worddata_logger.error('bad call to word_count_at({0})'.format(n))
            return 0
    def word_props_at(self, n):
        try:
            return self.vocab_vview[n][1]
        except Exception as whatever:
            worddata_logger.error('bad call to word_props_at({0})'.format(n))
            return (set())

    #
    # Return a sort vector to implement column-sorting and/or filtering. The
    # returned value is a list of index numbers to self.vocab_vview and
    # vocab_kview such that iterating over the list selects vocabulary items
    # in some order. The parameters are:
    #
    # col is the number of the table column, 0:word, 1:count, 2:properties.
    # The sort key is formed based on the column:
    #   0: key is the word-token
    #   1: key is nnnnnnword-token so that words with the same count are
    #      in sequence.
    #   2: fffffffword-token so that words with the same props are in sequence.
    #
    # order is Qt.AscendingOrder or Qt.DescendingOrder
    #
    # key_func is a callable used to extract or condition the key value when
    # a new key is added to a SortedDict, usually created by natsort.keygen()
    # and used to implement locale-aware and case-independent sorting.
    #
    # filter_func is a callable that examines a vocab entry and returns
    # True or False, meaning include or omit this entry from the vector.
    # Used to implement property filters or harmonic-sets.
    #
    # To implement Descending order we return a reversed() version of the
    # matching Ascending order vector.
    #
    # Because vectors are expensive to make, we cache them, so that to
    # return to a previous sort order takes near zero time. However we can't
    # cache every variation of a filtered vector, so when a filter_func is
    # passed we make the vector every time.
    #
    def _make_key_getter(self, col) :
        if col == 0 :
            return lambda j : self.vocab_kview[j]
        elif col == 1 :
            return lambda j : '{:05}:{}'.format( self.vocab_vview[j][0], self.vocab_kview[j] )
        else : # col == 2
            return lambda j : prop_string(self.vocab_vview[j][1]) + self.vocab_kview[j]

    def get_sort_vector( self, col, order, key_func = None, filter_func = None ) :
        if filter_func : # is not None,
            # create a sort vector from scratch, filtered
            getter_func = self._make_key_getter( col )
            sorted_dict = SortedDict( key_func )
            for j in range( len( self.vocab ) ) :
                if filter_func( self.vocab_kview[j], self.vocab_vview[j][1] ) :
                    k = getter_func( j )
                    sorted_dict[ k ] = j
            vector = sorted_dict.values()
            if order != Qt.AscendingOrder :
                vector = [j for j in reversed( vector ) ]
        else : # no filter_func, try to reuse a cached vector
            vector = self.sort_up_vectors[ col ]
            if not vector or key_func is not self.sort_key_funcs[ col ] :
                # there is no ascending vector for this column, or there
                # is one but it was made with a different key_func.
                getter_func = self._make_key_getter( col )
                sorted_dict = SortedDict( key_func )
                for j in range( len( self.vocab ) ) :
                    k = getter_func( j )
                    sorted_dict[ k ] = j
                vector = self.sort_up_vectors[ col ] = sorted_dict.values()
                self.sort_key_funcs[ col ] = key_func
            if order != Qt.AscendingOrder :
                # what is wanted is a descending order vector, do we have one?
                if self.sort_down_vectors[ col ] is None :
                    # no, so create one from the asc. vector we now have
                    self.sort_down_vectors[ col ] = [ j for j in reversed( vector ) ]
                # yes we do (now)
                vector = self.sort_down_vectors[ col ]
        # one way or another, vector is a sort vector
        # note the actual word count available through that vector
        self.active_word_count = len(vector)
        return vector

    # Return a reference to the good-words set
    def get_good_set(self):
        return self.good_words

    # Note the addition of a word to the good-words set. The word probably
    # (but does not have to) exist in the database; add GW and remove XX from
    # its properties.
    def add_to_good_set(self, word):
        self.good_words.add(word)
        if word in self.vocab_kview :
            [count, pset] = self.vocab[word]
            pset.add(GW)
            pset -= set([XX]) # conditional .remove()
            self.vocab[word] = [count,pset]

    # Note the removal of a word from the good-words set. The word exists in
    # the good-words set, because the wordview panel good-words list only
    # calls this for words it is displaying. The word may or may not exist in
    # the database. If it does, remove GW and set XX based on a spellcheck
    # test.
    def del_from_good_set(self, word):
        self.good_words.remove(word)
        if word in self.vocab_kview :
            [count, pset] = self.vocab[word]
            pset -= set([GW,XX])
            dic_tag = self.alt_tags.get(word)
            if not self.speller.check(word, dic_tag) :
                pset.add(XX)
            self.vocab[word] = [count, pset]

    # mostly used by unit test, get the index of a word by its key
    def word_index(self, w):
        try:
            return self.vocab_kview.index(w)
        except Exception as whatever:
            worddata_logger.error('bad call to word_index({0})'.format(w))
            return -1

    # The following methods are used by the edit syntax highlighter to set flags.
    #
    # 1. Check a token for spelling. We expect the vast majority of words
    # will be in the list. And for performance, we want to respond in as little
    # code as possible! So if we know the word, reply at once.
    #
    # 2. If the word in the document isn't in the vocab, perhaps it is not
    # a normalized string, so try again, normalized.
    #
    # 3 If the token is not in the list, add it to the vocabulary with null
    # properties (to speed up repeat calls) and return False, meaning it is
    # not misspelled. The opposite, returning True for misspelled, in a new
    # book before Refresh is done, would highlight everything.
    #
    def spelling_test(self, tok_str) :
        count, prop_set = self.vocab.get(tok_str,[0,set()])
        if count : # it was in the list
            return XX in prop_set
        tok_nlz = unicodedata.normalize('NFKC',tok_str)
        [count, prop_set] = self.vocab.get(tok_nlz,[0,set()])
        return XX in prop_set
    #
    # 2. Check a token for being in the scannos list. If no scannos
    # have been loaded, none will be hilited.
    #
    def scanno_test(self, tok_str) :
        return tok_str in self.scannos
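A minimal sketch of the sort-vector technique used by get_sort_vector above: row indices are inserted into a SortedDict built with a key function, and its values view is read back as a permutation. The words list and the str.lower key below are placeholders standing in for the vocabulary views and the natsort key function.

from sortedcontainers import SortedDict

words = ["Banana", "apple", "cherry"]

sd = SortedDict(str.lower)             # case-independent ordering of the word keys
for j, w in enumerate(words):
    sd[w] = j                          # key: the word, value: its row index
vector = list(sd.values())

print(vector)                          # [1, 0, 2]
print([words[j] for j in vector])      # ['apple', 'Banana', 'cherry']
print([words[j] for j in reversed(vector)])   # descending order is just the reverse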
Example #28
0
class KITTI360Loader(TrackingDatasetBase):
    """
    Load KITTI-360 dataset into a usable format.
    The dataset structure should follow the official documents.

    * Zip Files::

        - calibration.zip
        - data_3d_bboxes.zip
        - data_3d_semantics.zip
        - data_poses.zip
        - data_timestamps_sick.zip
        - data_timestamps_velodyne.zip
        - 2013_05_28_drive_0000_sync_sick.zip
        - 2013_05_28_drive_0000_sync_velodyne.zip
        - ...

    * Unzipped Structure::

        - <base_path directory>
            - calibration
            - data_2d_raw
                - 2013_05_28_drive_0000_sync
                - ...
            - data_2d_semantics
                - 2013_05_28_drive_0000_sync
                - ...
            - data_3d_raw
                - 2013_05_28_drive_0000_sync
                - ...
            - data_3d_semantics
                - 2013_05_28_drive_0000_sync
                - ...

    For description of constructor parameters, please refer to :class:`d3d.dataset.base.TrackingDatasetBase`
    
    :param interpolate_pose: Not all frames contain pose data in KITTI-360. The loader
        returns an interpolated pose if this param is set to True, otherwise returns None.
    :type interpolate_pose: bool
    :param compression: The compression type of the created zip for semantic files. It should be one of the
        compression types specified in :mod:`zipfile` module.
    :type compression: int 
    """
    VALID_CAM_NAMES = ['cam1', 'cam2', 'cam3',
                       'cam4']  # cam1 and cam2 are perspective cameras
    VALID_LIDAR_NAMES = ['velo']  # velo stands for velodyne
    VALID_OBJ_CLASSES = Kitti360Class

    FRAME_PATH_MAP = dict(
        sick=("data_3d_raw", "sick_points", "data",
              "data_timestamps_sick.zip"),
        velo=("data_3d_raw", "velodyne_points", "data",
              "data_timestamps_velodyne.zip"),
        cam1=("data_2d_raw", "image_00", "data_rect",
              "data_timestamps_perspective.zip"),
        cam2=("data_2d_raw", "image_01", "data_rect",
              "data_timestamps_perspective.zip"),
        cam3=("data_2d_raw", "image_02", "data_rgb",
              "data_timestamps_fisheye.zip"),
        cam4=("data_2d_raw", "image_03", "data_rgb",
              "data_timestamps_fisheye.zip"),
    )

    def __init__(self,
                 base_path,
                 phase="training",
                 inzip=False,
                 trainval_split=1,
                 trainval_random=False,
                 trainval_byseq=False,
                 nframes=0,
                 interpolate_pose=True,
                 compression=ZIP_STORED):
        super().__init__(base_path,
                         inzip=inzip,
                         phase=phase,
                         nframes=nframes,
                         trainval_split=trainval_split,
                         trainval_random=trainval_random,
                         trainval_byseq=trainval_byseq)

        if phase not in ['training', 'validation', 'testing']:
            raise ValueError("Invalid phase tag")

        self.base_path = Path(base_path)
        self.inzip = inzip
        self.phase = phase
        self.nframes = nframes
        self.interpolate_pose = interpolate_pose
        self.compression = compression

        # count total number of frames
        frame_count = dict()
        _dates = ["2013_05_28"]
        if self.inzip:
            _archives = [
                # ("sick", ".bin"), # SICK points are out of synchronization
                ("velodyne", ".bin"),
                ("image_00", ".png"),
                ("image_01", ".png"),
                ("image_02", ".png"),
                ("image_03", ".png")
            ]
            for aname, ext in _archives:
                globs = [
                    self.base_path.glob(f"{date}_drive_*_sync_{aname}.zip")
                    for date in _dates
                ]
                for archive in chain(*globs):
                    with ZipFile(archive) as data:
                        data_files = (name for name in data.namelist()
                                      if name.endswith(ext))

                        seq = archive.stem[:archive.stem.rfind("_")]
                        frame_count[seq] = sum(1 for _ in data_files)

                # successfully found
                if len(frame_count) > 0:
                    break
        else:
            _folders = [
                # ("data_3d_raw", "sick_points", "data"),
                ("data_3d_raw", "velodyne_points", "data"),
                ("data_2d_raw", "image_00", "data_rect"),
                ("data_2d_raw", "image_01", "data_rect"),
                ("data_2d_raw", "image_02", "data_rgb"),
                ("data_2d_raw", "image_03", "data_rgb"),
            ]
            for ftype, fname, dname in _folders:
                globs = [
                    self.base_path.glob(f"{ftype}/{date}_drive_*_sync")
                    for date in _dates
                ]
                for archive in chain(*globs):
                    if not archive.is_dir():  # skip calibration files
                        continue
                    if not (archive / fname / "data").exists():
                        continue

                    seq = archive.name
                    frame_count[seq] = sum(
                        1 for _ in (archive / fname / dname).iterdir())

                # successfully found
                if len(frame_count) > 0:
                    break

        if not frame_count:
            raise ValueError(
                "Cannot parse dataset, please check path, inzip option and file structure"
            )
        self.frame_dict = SortedDict(frame_count)

        self.frames = split_trainval_seq(phase, self.frame_dict,
                                         self.frame_dict, trainval_random,
                                         trainval_byseq)
        self._poses_idx = {}  # store loaded poses indices
        self._poses_t = {}  # pose translation
        self._poses_r = {}  # pose rotation
        self._3dobjects_cache = {}  # store loaded object labels
        self._3dobjects_mapping = {}  # store frame to objects mapping
        self._timestamp_cache = {}  # store timestamps

        # load calibration
        self._calibration = None
        self._preload_calib()

    def __len__(self):
        return len(self.frames)

    @property
    def sequence_ids(self):
        return list(self.frame_dict.keys())

    @property
    def sequence_sizes(self):
        return dict(self.frame_dict)

    def _locate_frame(self, idx):
        # use underlying frame index
        idx = self.frames[idx]

        for k, v in self.frame_dict.items():
            if idx < (v - self.nframes):
                return k, idx
            idx -= (v - self.nframes)
        raise ValueError("Index larger than dataset size")

    @expand_idx_name(VALID_CAM_NAMES)
    def camera_data(self, idx, names='cam1'):
        seq_id, frame_idx = idx

        _, folder_name, dname, _ = self.FRAME_PATH_MAP[names]
        fname = Path(seq_id, folder_name, dname, '%010d.png' % frame_idx)
        if self._return_file_path:
            return self.base_path / "data_2d_raw" / fname

        if self.inzip:
            with PatchedZipFile(self.base_path / f"{seq_id}_{folder_name}.zip",
                                to_extract=fname) as source:
                return load_image(source, fname, gray=False)
        else:
            return load_image(self.base_path / "data_2d_raw",
                              fname,
                              gray=False)

    @expand_idx_name(['velo'])
    def lidar_data(self, idx, names='velo'):
        assert names == 'velo'
        seq_id, frame_idx = idx

        # load velodyne points
        fname = Path(seq_id, "velodyne_points", "data",
                     '%010d.bin' % frame_idx)
        if self._return_file_path:
            return self.base_path / "data_3d_raw" / fname

        if self.inzip:
            with PatchedZipFile(self.base_path / f"{seq_id}_velodyne.zip",
                                to_extract=fname) as source:
                return load_velo_scan(source, fname)
        else:
            return load_velo_scan(self.base_path / "data_3d_raw", fname)

    def _preload_3dobjects(self, seq_id):
        assert self.phase in ["training", "validation"
                              ], "Testing set doesn't contains label"
        if seq_id in self._3dobjects_mapping:
            return
        assert seq_id in self.sequence_ids

        fname = Path("data_3d_bboxes", "train", f"{seq_id}.xml")
        if self.inzip:
            with PatchedZipFile(self.base_path / "data_3d_bboxes.zip",
                                to_extract=fname) as source:
                objlist, fmap = load_bboxes(source, fname)
        else:
            objlist, fmap = load_bboxes(self.base_path, fname)

        self._3dobjects_cache[seq_id] = objlist
        self._3dobjects_mapping[seq_id] = fmap

    @expand_idx
    def annotation_3dobject(
        self,
        idx,
        raw=False,
        visible_range=80
    ):  # TODO: it seems that dynamic objects need interpolation
        '''
        :param visible_range: range for visible objects. Objects beyond that distance will be removed when reporting
        '''
        assert not self._return_file_path, "The annotation is not in a single file!"
        seq_id, frame_idx = idx
        self._preload_3dobjects(seq_id)
        objects = [
            self._3dobjects_cache[seq_id][i.data]
            for i in self._3dobjects_mapping[seq_id][frame_idx]
        ]
        if raw:
            return objects

        self._preload_poses(seq_id)
        pr, pt = self._poses_r[seq_id][frame_idx], self._poses_t[seq_id][
            frame_idx]

        boxes = Target3DArray(frame="pose")
        for box in objects:
            RS, T = box.transform[:3, :3], box.transform[:3, 3]
            S = np.linalg.norm(RS, axis=0)  # scale
            R = Rotation.from_matrix(RS / S)  # rotation
            R = pr.inv() * R  # relative rotation
            T = pr.inv().as_matrix().dot(T - pt)  # relative translation

            # skip static objects beyond vision
            if np.linalg.norm(T) > visible_range:
                continue

            global_id = box.semanticId * 1000 + box.instanceId
            tag = ObjectTag(kittiId2label[box.semanticId].name, Kitti360Class)
            boxes.append(ObjectTarget3D(T, R, S, tag, tid=global_id))
        return boxes

    def _preload_calib(self):
        # load data
        import yaml
        if self.inzip:
            source = ZipFile(self.base_path / "calibration.zip")
        else:
            source = self.base_path

        cam2pose = load_calib_file(source, "calibration/calib_cam_to_pose.txt")
        perspective = load_calib_file(source, "calibration/perspective.txt")
        if self.inzip:
            cam2velo = np.fromstring(
                source.read("calibration/calib_cam_to_velo.txt"), sep=" ")
            sick2velo = np.fromstring(
                source.read("calibration/calib_sick_to_velo.txt"), sep=" ")
            intri2 = yaml.safe_load(
                source.read("calibration/image_02.yaml")[10:])  # skip header
            intri3 = yaml.safe_load(
                source.read("calibration/image_03.yaml")[10:])
        else:
            cam2velo = np.loadtxt(source / "calibration/calib_cam_to_velo.txt")
            sick2velo = np.loadtxt(source /
                                   "calibration/calib_sick_to_velo.txt")
            intri2 = yaml.safe_load(
                (source / "calibration/image_02.yaml").read_text()[10:])
            intri3 = yaml.safe_load(
                (source / "calibration/image_03.yaml").read_text()[10:])

        if self.inzip:
            source.close()

        # parse calibration
        calib = TransformSet("pose")
        calib.set_intrinsic_lidar("velo")
        calib.set_intrinsic_lidar("sick")
        calib.set_intrinsic_camera("cam1",
                                   perspective["P_rect_00"].reshape(3, 4),
                                   perspective["S_rect_00"],
                                   rotate=False)
        calib.set_intrinsic_camera("cam2",
                                   perspective["P_rect_01"].reshape(3, 4),
                                   perspective["S_rect_01"],
                                   rotate=False)

        def parse_mei_camera(intri):
            size = [intri['image_width'], intri['image_height']]
            distorts = intri['distortion_parameters']
            distorts = np.array([
                distorts['k1'], distorts['k2'], distorts['p1'], distorts['p2']
            ])
            projection = intri['projection_parameters']
            pmatrix = np.diag([projection["gamma1"], projection["gamma2"], 1])
            pmatrix[0, 2] = projection["u0"]
            pmatrix[1, 2] = projection["v0"]
            return size, pmatrix, distorts, intri['mirror_parameters']['xi']

        S, P, D, xi = parse_mei_camera(intri2)
        calib.set_intrinsic_camera("cam3",
                                   P,
                                   S,
                                   distort_coeffs=D,
                                   intri_matrix=P,
                                   mirror_coeff=xi)
        S, P, D, xi = parse_mei_camera(intri3)
        calib.set_intrinsic_camera("cam4",
                                   P,
                                   S,
                                   distort_coeffs=D,
                                   intri_matrix=P,
                                   mirror_coeff=xi)

        calib.set_extrinsic(cam2pose["image_00"].reshape(3, 4),
                            frame_from="cam1")
        calib.set_extrinsic(cam2pose["image_01"].reshape(3, 4),
                            frame_from="cam2")
        calib.set_extrinsic(cam2pose["image_02"].reshape(3, 4),
                            frame_from="cam3")
        calib.set_extrinsic(cam2pose["image_03"].reshape(3, 4),
                            frame_from="cam4")
        calib.set_extrinsic(cam2velo.reshape(3, 4),
                            frame_from="cam1",
                            frame_to="velo")
        calib.set_extrinsic(sick2velo.reshape(3, 4),
                            frame_from="sick",
                            frame_to="velo")
        self._calibration = calib

    def calibration_data(self, idx):
        return self._calibration

    # XXX: fix the points (with large distance?) with bounding box
    def _parse_semantic_ply(self, ntqdm, seq: str, fname: Path, dynamic: bool,
                            result_path: Path, expand_frames: int):
        ''' match point cloud in aggregated semantic point clouds '''
        import pcl
        from filelock import FileLock
        from sklearn.neighbors import KDTree

        fstart, fend = fname.stem.split('_')
        fstart, fend = int(fstart), int(fend)
        fstart, fend = max(fstart - expand_frames,
                           0), min(fend + expand_frames,
                                   self.sequence_sizes[seq])
        frame_desc = "%s frames %d-%d" % ("dynamic" if dynamic else "static",
                                          fstart, fend)

        _logger.debug("loading semantics for " + frame_desc)
        semantics = pcl.io.load_ply(str(fname))
        if len(semantics) == 0:
            return

        # create fast semantic id to Kitti360Class id mapping
        idmap = np.zeros(max(id2label.keys()) + 1, dtype='u1')
        for i in range(len(idmap)):
            idmap[i] = id2label[i].name.value

        # load semantics for static points
        if dynamic:
            timestamps = semantics.to_ndarray()['timestamp'].flatten()
        else:
            tree = KDTree(semantics.xyz)

        # iterate all semantic label files
        for i in tqdm.trange(fstart,
                             fend,
                             desc=frame_desc,
                             position=ntqdm,
                             leave=False):
            # load semantics for dynamic points
            if dynamic:
                cur_semantics = semantics[timestamps == i]
                if len(cur_semantics) == 0:
                    continue
                tree = KDTree(cur_semantics.xyz)
            else:
                cur_semantics = semantics

            def update_semantics(cloud, name, idx):
                label_path = result_path / name / ("%010d.npz" % idx)
                dist_path = result_path / name / ("%010d.dist.npy" % idx)
                lock_path = FileLock(result_path / name / ("%010d.lock" % idx))

                # deal with empty (sick cloud)
                if len(cloud) == 0:
                    np.savez(label_path,
                             rgb=np.array([], dtype='u1').reshape(0, 3),
                             semantic=np.array([], dtype='u1'),
                             instance=np.array([], dtype='u2'),
                             visible=np.array([], dtype=bool))
                    np.save(dist_path, np.array([]))
                    return

                # match point cloud
                distance, sidx = tree.query(cloud, return_distance=True)
                selected = cur_semantics.to_ndarray()[sidx]
                distance = distance.flatten()

                # fetch labels from nearest points
                rgb = selected["rgb"].flatten().view('4u1')[:, :3]
                slabels = selected["semantic"].flatten()
                slabels = idmap[slabels]
                ilabels = selected["instance"].flatten().astype('u2')
                visible = selected["visible"].flatten()

                # update saved labels
                with lock_path:
                    if dist_path.exists():
                        old_distance = np.load(dist_path)
                        update_mask = (distance < old_distance)
                        distance = np.where(update_mask, distance,
                                            old_distance)

                        old_labels = np.load(label_path)
                        old_visible = np.unpackbits(old_labels["visible"],
                                                    count=len(cloud))
                        rgb = np.where(update_mask.reshape(-1, 1), rgb,
                                       old_labels['rgb'])
                        slabels = np.where(update_mask, slabels,
                                           old_labels["semantic"])
                        ilabels = np.where(update_mask, ilabels,
                                           old_labels["instance"])
                        visible = np.where(update_mask, visible, old_visible)

                    np.savez(label_path,
                             rgb=rgb,
                             semantic=slabels,
                             instance=ilabels,
                             visible=np.packbits(visible))
                    np.save(dist_path, distance)

            # load and transform velodyne points
            cloud = self.lidar_data((seq, i), names="velo", bypass=True)
            cloud = self._calibration.transform_points(cloud[:, :3],
                                                       frame_to="pose",
                                                       frame_from="velo")
            cloud = cloud.dot(
                self._poses_r[seq][i].as_matrix().T) + self._poses_t[seq][i]
            update_semantics(cloud, 'velodyne', i)

            # load and transform sick points
            for item in self.intermediate_data((seq, i),
                                               names="sick",
                                               ninter_frames=None,
                                               report_semantic=False):
                cloud = np.insert(item.data, 2, 0, axis=1)
                cloud = self._calibration.transform_points(cloud,
                                                           frame_to="pose",
                                                           frame_from="sick")
                cloud = cloud.dot(
                    item.pose.orientation.as_matrix().T) + item.pose.position
                update_semantics(cloud, 'sick', item.index)

    _semantic_dtypes = dict(rgb='3u1',
                            semantic='u1',
                            instance='u2',
                            visible='u1')

    def _preload_3dsemantics(self,
                             seq,
                             nworkers=7,
                             expand_frames=150,
                             stats_error=False):
        """
        This method will convert the combined semantic point cloud back in to frames

        expand_frames: number of expanded frames to be considered based on original sequence split
            Better painting results will be generated with larger expansion, but it will be slower
        """
        assert seq in self.sequence_ids

        if self.inzip:
            if (self.base_path / f"{seq}_semantics.zip").exists():
                return
            result_path = Path(tempfile.mkdtemp())
            data_path = Path(tempfile.mkdtemp())
        else:
            result_path = self.base_path / "data_3d_semantics" / seq
            data_path = self.base_path
            if (result_path / 'velodyne').exists():
                return
        velo_path = (result_path / 'velodyne')
        sick_path = (result_path / 'sick')
        velo_path.mkdir()
        sick_path.mkdir()

        try:
            if self.inzip:
                _logger.info("Extracting semantic labels of %s to %s...", seq,
                             data_path)
                with ZipFile(self.base_path /
                             "data_3d_semantics.zip") as archive:
                    files = [
                        info for info in archive.infolist()
                        if info.filename.startswith("data_3d_semantics/" +
                                                    seq) and not info.is_dir()
                    ]
                    for info in tqdm.tqdm(files,
                                          desc="Extracing semantic labels",
                                          leave=False):
                        archive.extract(info, data_path)

            _logger.info("Converting 3d semantic labels for sequence %s...",
                         seq)
            tstart = time.time()

            # load poses for aligning point cloud
            self._preload_poses(seq)

            pool = NumberPool(nworkers)
            for fspan in (data_path / "data_3d_semantics" / seq /
                          "static").glob("*.ply"):
                pool.apply_async(
                    self._parse_semantic_ply,
                    (seq, fspan, False, result_path, expand_frames))
            for fspan in (data_path / "data_3d_semantics" / seq /
                          "dynamic").glob("*.ply"):
                pool.apply_async(
                    self._parse_semantic_ply,
                    (seq, fspan, True, result_path, expand_frames))
            pool.close()
            pool.join()
            tend = time.time()
            _logger.info("Conversion finished, consumed time: %.4fs",
                         tend - tstart)

            # statistics on distance error
            if stats_error:
                total_points = unmatched_points = 0
                for f in tqdm.tqdm(list(velo_path.glob("*.dist.npy")),
                                   desc="Revisiting velodyne distance arrays",
                                   leave=False):
                    d = np.load(f)
                    total_points += len(d)
                    unmatched_points += np.sum(d > 5)
                _logger.debug("Velodyne unmatched ratio (distance > 5): %.2f",
                              unmatched_points / total_points * 100)

                total_points = unmatched_points = 0
                for f in tqdm.tqdm(list(sick_path.glob("*.dist.npy")),
                                   desc="Revisiting sick distance arrays",
                                   leave=False):
                    d = np.load(f)
                    total_points += len(d)
                    unmatched_points += np.sum(d > 5)
                _logger.debug("Sick unmatched ratio (distance > 5): %.2f",
                              unmatched_points / total_points * 100)

            _logger.info("Saving indexed semantic labels...")
            velo_files = ((f, 'velodyne') for f in velo_path.glob("*.npz"))
            sick_files = ((f, 'sick') for f in sick_path.glob("*.npz"))
            if self.inzip:
                # save results into zipfile
                with ZipFile(self.base_path / f"{seq}_semantics.zip",
                             "w",
                             compression=self.compression) as archive:
                    for f, sensor in chain(velo_files, sick_files):
                        labels = np.load(f)
                        name_out = f.stem + '.bin'
                        for key in labels:
                            archive.writestr(
                                f"data_3d_semantics/{seq}/{sensor}/{key}/{name_out}",
                                labels[key].tobytes())

            else:
                # save results into separate binary files
                for key in self._semantic_dtypes:
                    (result_path / 'velodyne' / key).mkdir()
                    (result_path / 'sick' / key).mkdir()

                for f, sensor in chain(velo_files, sick_files):
                    labels = np.load(f)
                    name_out = f.stem + '.bin'
                    for key in labels:
                        labels[key].tofile(result_path / sensor / key /
                                           name_out)
                    f.unlink()

        finally:
            # clean up
            if self.inzip:
                # remove temporary files
                shutil.rmtree(result_path)
                shutil.rmtree(data_path)
            else:
                # remove distance files
                for f in (list(velo_path.iterdir()) +
                          list(sick_path.iterdir())):
                    if f.suffix == ".npy" or f.suffix == ".lock":
                        f.unlink()
            _logger.debug("Conversion clean up finished!")

    @expand_idx
    def annotation_3dpoints(self, idx):
        seq_id, frame_idx = idx
        self._preload_3dsemantics(seq_id)

        fnames = {}
        for key in self._semantic_dtypes:
            fnames[key] = Path("data_3d_semantics", seq_id, "velodyne", key,
                               "%010d.bin" % frame_idx)
        if self._return_file_path:
            return edict({k: self.base_path / v for k, v in fnames.items()})

        data = edict()
        if self.inzip:
            with PatchedZipFile(self.base_path / f"{seq_id}_semantics.zip",
                                to_extract=[str(v)
                                            for v in fnames.values()]) as ar:
                for k, v in fnames.items():
                    data[k] = np.frombuffer(ar.read(str(v)),
                                            dtype=self._semantic_dtypes[k])
        else:
            for k, v in fnames.items():
                data[k] = np.fromfile(self.base_path / v,
                                      dtype=self._semantic_dtypes[k])

        data.visible = np.unpackbits(data.visible,
                                     count=len(data.semantic)).astype(bool)
        return data

    def annotation_2dpoints(self, idx):
        raise NotImplementedError()

    def _preload_timestamps(self, seq, name):
        if (seq, name) in self._timestamp_cache:
            return
        assert seq in self.sequence_ids

        folder, subfolder, _, archive = self.FRAME_PATH_MAP[name]
        fname = Path(seq, subfolder, "timestamps.txt")
        if self.inzip:
            with PatchedZipFile(self.base_path / archive,
                                to_extract=fname) as data:
                ts = load_timestamps(data, fname, formatted=True)
        else:
            ts = load_timestamps(self.base_path / folder,
                                 fname,
                                 formatted=True)

        self._timestamp_cache[(seq, name)] = ts.astype(int) // 1000

    @expand_idx
    def timestamp(self, idx, names="velo"):
        if names == "sick":
            raise NotImplementedError(
                "Indexing for sick points is not available yet!")

        seq_id, frame_idx = idx
        self._preload_timestamps(seq_id, names)

        return self._timestamp_cache[(seq_id, names)][frame_idx]

    def _preload_poses(self, seq):
        if seq in self._poses_idx:
            return
        assert seq in self.sequence_ids

        fname = Path("data_poses", seq, "poses.txt")
        if self.inzip:
            with PatchedZipFile(self.base_path / "data_poses.zip",
                                to_extract=fname) as data:
                plist = np.loadtxt(data.open(str(fname)))
        else:
            plist = np.loadtxt(self.base_path / fname)

        # do interpolation
        pose_indices = plist[:, 0].astype(int)
        pose_matrices = plist[:, 1:].reshape(-1, 3, 4)
        positions = pose_matrices[:, :, 3]
        rotations = Rotation.from_matrix(pose_matrices[:, :, :3])

        ts_frame = "velo"  # the frame used for timestamp extraction
        self._preload_timestamps(seq, ts_frame)
        timestamps = self._timestamp_cache[(seq, ts_frame)]

        fpos = interp1d(timestamps[pose_indices],
                        positions,
                        axis=0,
                        fill_value="extrapolate")
        positions = fpos(timestamps)
        frot = interp1d(timestamps[pose_indices],
                        rotations.as_rotvec(),
                        axis=0,
                        fill_value="extrapolate")
        rotations = frot(timestamps)

        self._poses_idx[seq] = set(pose_indices)
        self._poses_t[seq] = positions
        self._poses_r[seq] = Rotation.from_rotvec(rotations)

    @expand_idx
    def pose(self, idx):
        seq_id, frame_idx = idx

        self._preload_poses(seq_id)
        if frame_idx not in self._poses_idx[
                seq_id] and not self.interpolate_pose:
            return None

        return EgoPose(self._poses_t[seq_id][frame_idx],
                       self._poses_r[seq_id][frame_idx])

    @expand_idx_name(['sick'])
    def intermediate_data(self,
                          idx,
                          names='sick',
                          ninter_frames=None,
                          report_semantic=True):
        assert names == 'sick', "Only intermediate data for sick lidar is available in Kitti360!"
        seq_id, frame_idx = idx

        self._preload_timestamps(seq_id, names)
        if report_semantic:
            self._preload_3dsemantics(seq_id)

        # find the corresponding sick indices
        ts_frame = "velo"  # the frame used for timestamp extraction
        self._preload_timestamps(seq_id, ts_frame)
        key_ts_list = self._timestamp_cache[(seq_id, ts_frame)]
        key_ts_prev = key_ts_list[frame_idx - 1] if frame_idx != 0 else 0
        key_ts = key_ts_list[frame_idx]
        sick_ts_idxa = bisect_right(self._timestamp_cache[(seq_id, names)],
                                    key_ts_prev)
        sick_ts_idxb = bisect_right(self._timestamp_cache[(seq_id, names)],
                                    key_ts)

        # do pose interpolation
        if self.interpolate_pose:
            self._preload_poses(seq_id)
            fpos = interp1d(key_ts_list,
                            self._poses_t[seq_id],
                            axis=0,
                            fill_value="extrapolate")
            frot = interp1d(key_ts_list,
                            self._poses_r[seq_id].as_rotvec(),
                            axis=0,
                            fill_value="extrapolate")

        # load corresponding sick data
        sick_ts_list = self._timestamp_cache[(seq_id, names)]
        sick_ts_idx_list = list(range(sick_ts_idxa, sick_ts_idxb))
        if ninter_frames is not None:
            sick_ts_idx_list = sick_ts_idx_list[-ninter_frames:]
        result = []
        for sick_idx in sick_ts_idx_list:
            sick_ts = sick_ts_list[sick_idx]
            item = edict(index=sick_idx, timestamp=sick_ts)
            if self.interpolate_pose:
                position, rotation = fpos(sick_ts), frot(sick_ts)
                rotation = Rotation.from_rotvec(rotation)
                item.pose = EgoPose(position, rotation)

            item.file = Path(seq_id, "sick_points", "data",
                             '%010d.bin' % sick_idx)
            if report_semantic:
                for key in self._semantic_dtypes:
                    item[key] = Path("data_3d_semantics", seq_id, "sick", key,
                                     "%010d.bin" % sick_idx)
            result.append(item)

        if self.inzip:
            namelist = [item.file for item in result]
            with PatchedZipFile(self.base_path / f"{seq_id}_sick.zip",
                                to_extract=namelist) as source:
                for item in result:
                    item.data = load_sick_scan(source, item.pop("file"))

            if report_semantic:
                namelist = []  # gather all files to read in advance
                for item in result:
                    for key in self._semantic_dtypes:
                        namelist.append(item[key])

                with PatchedZipFile(self.base_path / f"{seq_id}_semantics.zip",
                                    to_extract=namelist) as source:
                    for item in result:
                        for key in self._semantic_dtypes:
                            item[key] = np.frombuffer(
                                source.read(str(item[key])),
                                dtype=self._semantic_dtypes[key])
                        item.visible = np.unpackbits(
                            item.visible, count=len(item.data)).astype(bool)
        else:
            for item in result:
                if not self._return_file_path:
                    item.data = load_sick_scan(self.base_path / "data_3d_raw",
                                               item.pop("file"))
                else:
                    item.file = self.base_path / "data_3d_raw" / item.file

            if report_semantic:
                for item in result:
                    if not self._return_file_path:
                        for key in self._semantic_dtypes:
                            item[key] = np.fromfile(
                                self.base_path / item[key],
                                dtype=self._semantic_dtypes[key])
                        item.visible = np.unpackbits(
                            item.visible, count=len(item.data)).astype(bool)
                    else:
                        for key in self._semantic_dtypes:
                            item[key] = self.base_path / item[key]

        return result
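A hypothetical usage sketch for the loader above. The class name KITTI360Loader and the constructor arguments are assumptions (the constructor lies outside this excerpt); only the accessor methods are taken from the code itself.

# Hypothetical usage of the KITTI-360 loader shown above; class name and
# constructor arguments are assumptions, the accessors come from the snippet.
loader = KITTI360Loader("/data/KITTI-360", inzip=False, phase="training")

print(len(loader), loader.sequence_ids)      # number of frames and sequence names
cloud = loader.lidar_data(0)                 # velodyne scan of the first frame
image = loader.camera_data(0, names="cam1")  # perspective camera image
boxes = loader.annotation_3dobject(0)        # 3D boxes reported in the "pose" frame
ego = loader.pose(0)                         # interpolated ego pose, if available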
Example #29
0
    def calculate_scores(self):
        """
        Function to calculate a score for each transcript, given the metrics derived
        with the calculate_metrics method and the scoring scheme provided in the JSON configuration.
        If any requirements have been specified, all transcripts which do not pass them
        will be assigned a score of 0 and subsequently ignored.
        Scores are rounded to the nearest integer.
        """

        if self.scores_calculated is True:
            return

        self.get_metrics()
        if not hasattr(self, "logger"):
            # No logger attached: fall back to a standard-library logger (assumed)
            import logging
            self.logger = logging.getLogger(self.id)
            self.logger.setLevel("DEBUG")
        self.logger.debug("Calculating scores for {0}".format(self.id))

        self.scores = dict()
        for tid in self.transcripts:
            self.scores[tid] = dict()
            # Add the score for the transcript source
            self.scores[tid]["source_score"] = self.transcripts[tid].source_score

        if self.regressor is None:
            for param in self.json_conf["scoring"]:
                self._calculate_score(param)

            for tid in self.scores:
                self.transcripts[tid].scores = self.scores[tid].copy()

            for tid in self.transcripts:

                if tid in self.__orf_doubles:
                    del self.scores[tid]
                    continue
                self.transcripts[tid].score = sum(self.scores[tid].values())
                self.scores[tid]["score"] = self.transcripts[tid].score

        else:
            valid_metrics = self.regressor.metrics
            metric_rows = SortedDict()
            for tid, transcript in sorted(self.transcripts.items(), key=operator.itemgetter(0)):
                for param in valid_metrics:
                    self.scores[tid][param] = "NA"
                row = []
                for attr in valid_metrics:
                    val = getattr(transcript, attr)
                    if isinstance(val, bool):
                        if val:
                            val = 1
                        else:
                            val = 0
                    row.append(val)
                metric_rows[tid] = row
            # scores = SortedDict.fromkeys(metric_rows.keys())
            for pos, score in enumerate(self.regressor.predict(list(metric_rows.values()))):
                tid = list(metric_rows.keys())[pos]
                if tid in self.__orf_doubles:
                    del self.scores[tid]
                    continue
                self.scores[tid]["score"] = score
                self.transcripts[tid].score = score

        self.metric_lines_store = []
        for row in self.prepare_metrics():
            if row["tid"] in self.__orf_doubles:
                continue
            else:
                self.metric_lines_store.append(row)

        for doubled in self.__orf_doubles:
            for partial in self.__orf_doubles[doubled]:
                if partial in self.transcripts:
                    del self.transcripts[partial]

        self.scores_calculated = True
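The regressor branch above relies on a SortedDict iterating keys and values in the same sorted order, so the positional output of predict() can be mapped back to transcript ids. A minimal, self-contained sketch of that alignment (the regressor is replaced by a trivial stand-in):

from sortedcontainers import SortedDict

metric_rows = SortedDict()
metric_rows["tx2"] = [3.0, 1.0]
metric_rows["tx1"] = [5.0, 2.0]

# keys() and values() iterate in the same sorted key order, so the i-th
# prediction corresponds to the i-th transcript id
predictions = [sum(row) for row in metric_rows.values()]  # stand-in for regressor.predict
scores = dict(zip(metric_rows.keys(), predictions))
print(scores)  # {'tx1': 7.0, 'tx2': 4.0}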
Example #30
0
class BasePlotBuilder:

    title = ""
    x_label = "X"
    y_label = "Y"

    style_defaults = {
        # viewport padding
        "vp_padding_top": 0.02,
        "vp_padding_bottom": 0.02,
        "vp_padding_left": 0.02,
        "vp_padding_right": 0.02,
        # font size control
        # matplotlib FontProperties
        "font-large": None,
        "font": None,
        "font-small": None
    }

    def __init__(self, **kwargs):
        """
        Init the base plot. The accepted keyword arguments are style
        configuration parameters with defaults in
        BasePlotBuilder.style_defaults.
        """
        style = self._pop_style_kwargs(kwargs)
        super().__init__(**kwargs)
        fig, ax = self.make_axes()

        logger.debug("Initialize plot '%s'", self.title)

        self.fig = fig
        """The matplotlib figure"""

        self.ax = ax
        """The matplotlib axes"""

        self._patch_builders = []
        """
        List of tuples (dataset, [builders,...]) holding the
        registered patch builders for each dataset.
        """

        self._legend_handles = []
        """Legend handles created by the patch builders."""

        self._xticks = SortedDict()
        """X axis ticks returned by patch builders"""

        self._yticks = SortedDict()
        """Y axis ticks returned by patch builders"""

        self._view_box = Bbox.from_bounds(0, 0, 0, 0)
        """The viewport bounding box in data coordinates."""

        self._style = style
        """Style options."""

    def _pop_style_kwargs(self, kwargs):
        """
        Extract the style parameters from kwargs

        :return: dict with style parameters
        """
        return {
            k: kwargs.pop(k, self.style_defaults[k])
            for k in self.style_defaults.keys()
        }

    def _get_axes_kwargs(self):
        """
        Return the kwargs used to add the axes to the figure

        :return: dict
        """
        return {}

    def _get_figure_kwargs(self):
        """
        Return the kwargs used to create the plot figure

        :return: kwargs dict for :func:`pyplot.figure`
        """
        return {"figsize": (15, 10)}

    def _get_axes_rect(self):
        """
        Return the rect (in % of figure width and height)

        :return: list (4 floats, [left, bottom, width, height])
        """
        return [0.05, 0.15, 0.9, 0.8]

    def _get_legend_kwargs(self):
        """
        Return the kwargs used to add the legend to the figure

        :return: kwargs dict for :meth:`Axes.legend`
        """
        return {"handles": self._legend_handles, "prop": self._style["font"]}

    def _get_savefig_kwargs(self):
        """
        Return the kwargs for savefig() to control how
        the figure is saved to a file.

        :return: kwargs dict for :meth:`Figure.savefig`
        """
        return {}

    def _get_title_kwargs(self):
        """
        Return the kwargs for the axes.set_title call. This
        can be used to control the text properties of the title.

        :return: kwargs dict for :meth:`Axes.set_title`
        """
        if self._style["font"]:
            return {"fontproperties": self._style["font"]}
        return {}

    def _get_xlabels_kwargs(self):
        """
        Return the kwargs for the axes.set_xticklabels. This
        can be used to control the text properties of the labels.

        :return: kwargs dict for :meth:`Axes.set_xticklabels`
        """
        if self._style["font"]:
            return {"fontproperties": self._style["font"]}
        return {}

    def _get_ylabels_kwargs(self):
        """
        Return the kwargs for the axes.set_yticklabels. This
        can be used to control the text properties of the labels.

        :return: kwargs dict for :meth:`Axes.set_yticklabels`
        """
        if self._style["font"]:
            return {"fontproperties": self._style["font"]}
        return {}

    def _get_xlabel_kwargs(self):
        """
        Return the kwargs for the axes.set_xlabel.

        :return: kwargs dict for :meth:`Axes.set_xlabel`
        """
        if self._style["font"]:
            return {"fontproperties": self._style["font"]}
        return {}

    def _get_ylabel_kwargs(self):
        """
        Return the kwargs for the axes.set_ylabel.

        :return: kwargs dict for :meth:`Axes.set_ylabel`
        """
        if self._style["font"]:
            return {"fontproperties": self._style["font"]}
        return {}

    def make_axes(self):
        """
        Build the figure and axes for the plot

        :return: tuple containing the figure and the axes
        """
        fig = plt.figure(**self._get_figure_kwargs())
        rect = self._get_axes_rect()
        ax = fig.add_axes(rect, **self._get_axes_kwargs())
        return (fig, ax)

    def _dbg_repr_patch_builders(self):
        """
        Print a debug representation of the patch builders in the
        dict of patch builders in BasePlotBuilder.
        """
        pairs = map(
            lambda b: "dataset %s -> %s" % (b[0], list(map(str, b[1]))),
            self._patch_builders)
        return reduce(lambda p, a: "%s\n%s" % (a, p), pairs, "")


    # @profile
    def make_patches(self):
        """
        Build the patches from all the registered patch builders.
        Each dataset is iterated only once; every patch builder is
        invoked on each item during that iteration. Special handling can
        be done in subclasses, at the expense of performance, by passing
        a reduced set of patch_builders to this method.
        """
        # XXX experimental replace start
        logger.debug("Make patches:\n%s", self._dbg_repr_patch_builders())
        for idx, (dataset, builders) in enumerate(self._patch_builders):
            with ProgressTimer(
                    "Inspect dataset [%d/%d]" %
                (idx + 1, len(self._patch_builders)), logger):
                for item in dataset:
                    for b in builders:
                        b.inspect(item)
        builders = list(chain(*map(itemgetter(1), self._patch_builders)))
        bboxes = []
        for idx, b in enumerate(builders):
            with ProgressTimer(
                    "Make patches [%d/%d]" % (idx + 1, len(builders)), logger):
                # grab all the patches from the builders
                b.get_patches(self.ax)
            with ProgressTimer(
                    "Fetch plot elements [%d/%d]" % (idx + 1, len(builders)),
                    logger):
                # grab the viewport from the builders
                bboxes.append(b.get_bbox())
                # grab the legend from the builders
                self._legend_handles.extend(b.get_legend(self._legend_handles))
                # grab the x and y ticks
                xticks = b.get_xticks()
                xlabels = b.get_xlabels()
                yticks = b.get_yticks()
                ylabels = b.get_ylabels()
                if xlabels is None:
                    xlabels = repeat(None)
                if ylabels is None:
                    ylabels = repeat(None)
            self._xticks.update(zip(xticks, xlabels))
            self._yticks.update(zip(yticks, ylabels))
        # free the builders as we have no use for them anymore
        self.clear_patch_builders()
        self._view_box = Bbox.union(bboxes)
        logger.debug("Plot viewport %s", self._view_box)
        logger.debug("Num ticks: x:%d y:%d", len(self._xticks),
                     len(self._yticks))
        logger.debug("Legend entries %s",
                     list(map(lambda h: h.get_label(), self._legend_handles)))

    def make_plot(self):
        """
        Set the plot labels, ticks, viewport and legend from the
        patch builders.
        """
        logger.debug("Make plot")
        # set global font properties
        if self._style["font"]:
            self.ax.tick_params(labelsize=self._style["font"].get_size())
        self.ax.set_title(self.title, **self._get_title_kwargs())
        self.ax.set_xlabel(self.x_label, **self._get_xlabel_kwargs())
        self.ax.set_ylabel(self.y_label, **self._get_ylabel_kwargs())
        # set viewport
        # grab the viewbox and make a bounding box with it.
        xmin = self._view_box.xmin * (1 - self._style["vp_padding_left"])
        xmax = self._view_box.xmax * (1 + self._style["vp_padding_right"])
        ymin = self._view_box.ymin * (1 - self._style["vp_padding_bottom"])
        ymax = self._view_box.ymax * (1 + self._style["vp_padding_top"])
        self.ax.set_xlim(xmin, xmax)
        self.ax.set_ylim(ymin, ymax)
        self.ax.legend(**self._get_legend_kwargs())
        if self._xticks:
            self.ax.set_xticks(self._xticks.keys())
            self.ax.set_xticklabels(self._xticks.values(),
                                    **self._get_xlabels_kwargs())
        if self._yticks:
            self.ax.set_yticks(self._yticks.keys())
            self.ax.set_yticklabels(self._yticks.values(),
                                    **self._get_ylabels_kwargs())

    def process(self, out_file=None, show=True):
        """
        Produce the plot and display it or write it to a file

        :param out_file: output file path
        :type out_file: str
        :param show: show the plot in an interactive window
        :type show: bool
        """
        with ProgressTimer("Plot builder processing", logger):
            self.make_patches()
            self.make_plot()
            if out_file:
                self.fig.savefig(out_file, **self._get_savefig_kwargs())
            if show:
                # the fig.show() method does not enter the backend main loop
                # self.fig.show()
                plt.show()

    def register_patch_builder(self, dataset, builder):
        """
        Add a patch builder for a dataset

        :param dataset: dataset object
        :type dataset: iterable
        :param builder: the patch builder for items of the dataset
        :type builder: :class:`PatchBuilder`
        """
        builder.set_style(self._style)
        for entry in self._patch_builders:
            entry_dataset, entry_builders = entry
            if entry_dataset == dataset:
                entry_builders.append(builder)
                break
        else:
            self._patch_builders.append((dataset, [builder]))

    def clear_patch_builders(self):
        """Remove all registered patch builders."""
        self._patch_builders = []
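The _pop_style_kwargs helper above pulls the known style keys out of **kwargs before the remainder is forwarded up the MRO. A standalone sketch of that idiom (names here are illustrative, not part of the original class):

STYLE_DEFAULTS = {"vp_padding_top": 0.02, "font": None}

def pop_style_kwargs(kwargs, defaults=STYLE_DEFAULTS):
    # Remove every known style key from kwargs, falling back to its default.
    return {k: kwargs.pop(k, v) for k, v in defaults.items()}

kwargs = {"vp_padding_top": 0.05, "unrelated": 1}
style = pop_style_kwargs(kwargs)
print(style)   # {'vp_padding_top': 0.05, 'font': None}
print(kwargs)  # {'unrelated': 1} -- only non-style kwargs are left for super().__init__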
Example #31
0
class OrderTree(object):
    '''A price-ordered collection of OrderLists, backed by a SortedDict
    (playing the role of the red-black tree in the original design).

    The exchange will be using the OrderTree to hold bid and ask data (one OrderTree for each side).
    Keeping the information sorted by price makes it easier/faster to detect a match.
    '''
    def __init__(self):
        self.price_map = SortedDict()  # Dictionary containing price : OrderList object
        self.prices = self.price_map.keys()  # view of the prices, kept in sorted order
        self.order_map = {}  # Dictionary containing order_id : Order object
        self.volume = 0  # Contains total quantity from all Orders in tree
        self.num_orders = 0  # Contains count of Orders in tree
        self.depth = 0  # Number of different prices in tree (http://en.wikipedia.org/wiki/Order_book_(trading)#Book_depth)

    def __len__(self):
        return len(self.order_map)

    def get_price_list(self, price):
        return self.price_map[price]

    def get_order(self, order_id):
        return self.order_map[order_id]

    def create_price(self, price):
        self.depth += 1  # Add a price depth level to the tree
        new_list = OrderList()
        self.price_map[price] = new_list

    def remove_price(self, price):
        self.depth -= 1  # Remove a price depth level
        del self.price_map[price]

    def price_exists(self, price):
        return price in self.price_map

    def order_exists(self, order):
        return order in self.order_map

    def insert_order(self, quote):
        if self.order_exists(quote['order_id']):
            self.remove_order_by_id(quote['order_id'])
        self.num_orders += 1
        if quote['price'] not in self.price_map:
            # If the price is not in price_map yet, create a new price level
            self.create_price(quote['price'])
        order = Order(quote, self.price_map[quote['price']])  # Create an order
        self.price_map[order.price].append_order(
            order)  # Add the order to the OrderList in Price Map
        self.order_map[order.order_id] = order
        self.volume += order.quantity

    def update_order(self, order_update):
        order = self.order_map[order_update['order_id']]
        original_quantity = order.quantity
        if order_update['price'] != order.price:
            # Price changed. Remove order and update tree.
            order_list = self.price_map[order.price]
            order_list.remove_order(order)
            if len(order_list) == 0:
                # If there is nothing else in the OrderList, remove the price level
                self.remove_price(order.price)
            self.insert_order(order_update)
        else:
            # Quantity changed. Price is the same.
            order.update_quantity(order_update['quantity'],
                                  order_update['timestamp'])
        self.volume += order.quantity - original_quantity

    def remove_order_by_id(self, order_id):
        self.num_orders -= 1
        order = self.order_map[order_id]
        self.volume -= order.quantity
        order.order_list.remove_order(order)
        if len(order.order_list) == 0:
            self.remove_price(order.price)
        del self.order_map[order_id]

    def max_price(self):
        if self.depth > 0:
            return self.prices[-1]
        else:
            return None

    def min_price(self):
        if self.depth > 0:
            return self.prices[0]
        else:
            return None

    def max_price_list(self):
        if self.depth > 0:
            return self.get_price_list(self.max_price())
        else:
            return None

    def min_price_list(self):
        if self.depth > 0:
            return self.get_price_list(self.min_price())
        else:
            return None
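A minimal driving sketch for the OrderTree above. It assumes Order and OrderList implementations matching the interface used by insert_order and remove_order_by_id (append_order/remove_order on OrderList; price, quantity, order_id and order_list on Order), which are not shown in this excerpt.

# Assumes Order/OrderList classes with the interface used by OrderTree;
# quotes are plain dicts as in insert_order/update_order above.
bids = OrderTree()
bids.insert_order({'order_id': 1, 'price': 101.5, 'quantity': 10, 'timestamp': 1})
bids.insert_order({'order_id': 2, 'price': 99.0, 'quantity': 5, 'timestamp': 2})

print(len(bids), bids.volume)              # 2 orders, total quantity 15
print(bids.max_price(), bids.min_price())  # 101.5 and 99.0
bids.remove_order_by_id(1)
print(bids.max_price())                    # 99.0 once the top price level is empty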
Example #32
0
class Replica(HasActionQueue, MessageProcessor):
    def __init__(self, node: 'plenum.server.node.Node', instId: int,
                 isMaster: bool = False):
        """
        Create a new replica.

        :param node: Node on which this replica is located
        :param instId: the id of the protocol instance the replica belongs to
        :param isMaster: is this a replica of the master protocol instance
        """
        super().__init__()
        self.stats = Stats(TPCStat)

        self.config = getConfig()

        routerArgs = [(ReqDigest, self._preProcessReqDigest)]

        for r in [PrePrepare, Prepare, Commit]:
            routerArgs.append((r, self.processThreePhaseMsg))

        routerArgs.append((Checkpoint, self.processCheckpoint))
        routerArgs.append((ThreePCState, self.process3PhaseState))

        self.inBoxRouter = Router(*routerArgs)

        self.threePhaseRouter = Router(
                (PrePrepare, self.processPrePrepare),
                (Prepare, self.processPrepare),
                (Commit, self.processCommit)
        )

        self.node = node
        self.instId = instId

        self.name = self.generateName(node.name, self.instId)

        self.outBox = deque()
        """
        This queue is used by the replica to send messages to its node. Replica
        puts messages that are consumed by its node
        """

        self.inBox = deque()
        """
        This queue is used by the replica to receive messages from its node.
        Node puts messages that are consumed by the replica
        """

        self.inBoxStash = deque()
        """
        If messages need to go back on the queue, they go here temporarily and
        are put back on the queue on a state change
        """

        self.isMaster = isMaster

        # Indicates name of the primary replica of this protocol instance.
        # None in case the replica does not know who the primary of the
        # instance is
        self._primaryName = None    # type: Optional[str]

        # Requests waiting to be processed once the replica is able to decide
        # whether it is primary or not
        self.postElectionMsgs = deque()

        # PRE-PREPAREs that are waiting to be processed but do not have the
        # corresponding request digest. Happens when replica has not been
        # forwarded the request by the node but is getting 3 phase messages.
        # The value is a list since a malicious sender might send a PRE-PREPARE
        # with a different digest and, since we don't have the request finalised,
        # we store all PRE-PREPAREs
        self.prePreparesPendingReqDigest = {}   # type: Dict[Tuple[str, int], List]

        # PREPAREs that are stored by non primary replica for which it has not
        #  got any PRE-PREPARE. Dictionary that stores a tuple of view no and
        #  prepare sequence number as key and a deque of PREPAREs as value.
        # This deque is attempted to be flushed on receiving every
        # PRE-PREPARE request.
        self.preparesWaitingForPrePrepare = {}
        # type: Dict[Tuple[int, int], deque]

        # COMMITs that are stored for which there are no PRE-PREPARE or PREPARE
        # received
        self.commitsWaitingForPrepare = {}
        # type: Dict[Tuple[int, int], deque]

        # Dictionary of sent PRE-PREPARE that are stored by primary replica
        # which it has broadcasted to all other non primary replicas
        # Key of dictionary is a 2 element tuple with elements viewNo,
        # pre-prepare seqNo and value is a tuple of Request Digest and time
        self.sentPrePrepares = {}
        # type: Dict[Tuple[int, int], Tuple[Tuple[str, int], float]]

        # Dictionary of received PRE-PREPAREs. Key of dictionary is a 2
        # element tuple with elements viewNo, pre-prepare seqNo and value is
        # a tuple of Request Digest and time
        self.prePrepares = {}
        # type: Dict[Tuple[int, int], Tuple[Tuple[str, int], float]]

        # Dictionary of received Prepare requests. Key of dictionary is a 2
        # element tuple with elements viewNo, seqNo and value is a 2 element
        # tuple containing request digest and set of sender node names(sender
        # replica names in case of multiple protocol instances)
        # (viewNo, seqNo) -> ((identifier, reqId), {senders})
        self.prepares = Prepares()
        # type: Dict[Tuple[int, int], Tuple[Tuple[str, int], Set[str]]]

        self.commits = Commits()    # type: Dict[Tuple[int, int],
        # Tuple[Tuple[str, int], Set[str]]]

        # Set of tuples to keep track of ordered requests. Each tuple is
        # (viewNo, ppSeqNo)
        self.ordered = OrderedSet()        # type: OrderedSet[Tuple[int, int]]

        # Dictionary to keep track of which replica was primary during each
        # view. Key is the view no and value is the name of the primary
        # replica during that view
        self.primaryNames = {}  # type: Dict[int, str]

        # Holds msgs that are for later views
        self.threePhaseMsgsForLaterView = deque()
        # type: deque[(ThreePhaseMsg, str)]

        # Holds tuple of view no and prepare seq no of 3-phase messages it
        # received while it was not participating
        self.stashingWhileCatchingUp = set()       # type: Set[Tuple]

        # Commits which are not being ordered since commits with lower view
        # numbers and sequence numbers have not been ordered yet. Key is the
        # viewNo and value a map of pre-prepare sequence number to commit
        self.stashedCommitsForOrdering = {}         # type: Dict[int,
        # Dict[int, Commit]]

        self.checkpoints = SortedDict(lambda k: k[0])

        self.stashingWhileOutsideWaterMarks = deque()

        # Low water mark
        self._h = 0              # type: int

        # High water mark
        self.H = self._h + self.config.LOG_SIZE   # type: int

        self.lastPrePrepareSeqNo = self.h  # type: int

    @property
    def h(self) -> int:
        return self._h

    @h.setter
    def h(self, n):
        self._h = n
        self.H = self._h + self.config.LOG_SIZE

    @property
    def requests(self):
        return self.node.requests

    def shouldParticipate(self, viewNo: int, ppSeqNo: int):
        # The replica should only participate in the consensus process if its
        # node is participating and it did not stash any of this request's
        # 3-phase messages
        return self.node.isParticipating and (viewNo, ppSeqNo) \
                                             not in self.stashingWhileCatchingUp

    @staticmethod
    def generateName(nodeName: str, instId: int):
        """
        Create and return the name for a replica using its nodeName and
        instanceId.
         Ex: Alpha:1
        """
        return "{}:{}".format(nodeName, instId)

    @staticmethod
    def getNodeName(replicaName: str):
        return replicaName.split(":")[0]

    @property
    def isPrimary(self):
        """
        Is this node primary?

        :return: True if this node is primary, False otherwise
        """
        return self._primaryName == self.name if self._primaryName is not None \
            else None

    @property
    def primaryName(self):
        """
        Name of the primary replica of this replica's instance

        :return: Returns name if primary is known, None otherwise
        """
        return self._primaryName

    @primaryName.setter
    def primaryName(self, value: Optional[str]) -> None:
        """
        Set the value of isPrimary.

        :param value: the value to set isPrimary to
        """
        if value != self._primaryName:
            self._primaryName = value
            self.primaryNames[self.viewNo] = value
            logger.debug("{} setting primaryName for view no {} to: {}".
                         format(self, self.viewNo, value))
            logger.debug("{}'s primaryNames for views are: {}".
                         format(self, self.primaryNames))
            self._stateChanged()

    def _stateChanged(self):
        """
        A series of actions to be performed when the state of this replica
        changes.

        - UnstashInBox (see _unstashInBox)
        """
        self._unstashInBox()
        if self.isPrimary is not None:
            # TODO handle suspicion exceptions here
            self.process3PhaseReqsQueue()
            # TODO handle suspicion exceptions here
            try:
                self.processPostElectionMsgs()
            except SuspiciousNode as ex:
                self.outBox.append(ex)
                self.discard(ex.msg, ex.reason, logger.warning)

    def _stashInBox(self, msg):
        """
        Stash the specified message into the inBoxStash of this replica.

        :param msg: the message to stash
        """
        self.inBoxStash.append(msg)

    def _unstashInBox(self):
        """
        Append the inBoxStash to the right of the inBox.
        """
        self.inBox.extend(self.inBoxStash)
        self.inBoxStash.clear()

    def __repr__(self):
        return self.name

    @property
    def f(self) -> int:
        """
        Return the number of Byzantine Failures that can be tolerated by this
        system. Equal to (N - 1)/3, where N is the number of nodes in the
        system.
        """
        return self.node.f

    @property
    def viewNo(self):
        """
        Return the current view number of this replica.
        """
        return self.node.viewNo

    def isPrimaryInView(self, viewNo: int) -> Optional[bool]:
        """
        Return whether a primary has been selected for this view number.
        """
        return self.primaryNames[viewNo] == self.name

    def isMsgForLaterView(self, msg):
        """
        Return whether this request's view number is greater than the current
        view number of this replica.
        """
        viewNo = getattr(msg, "viewNo", None)
        return viewNo > self.viewNo

    def isMsgForCurrentView(self, msg):
        """
        Return whether this request's view number is equal to the current view
        number of this replica.
        """
        viewNo = getattr(msg, "viewNo", None)
        return viewNo == self.viewNo

    def isMsgForPrevView(self, msg):
        """
        Return whether this request's view number is less than the current view
        number of this replica.
        """
        viewNo = getattr(msg, "viewNo", None)
        return viewNo < self.viewNo

    def isPrimaryForMsg(self, msg) -> Optional[bool]:
        """
        Return whether this replica is primary if the request's view number is
        equal this replica's view number and primary has been selected for
        the current view.
        Return None otherwise.

        :param msg: message
        """
        if self.isMsgForLaterView(msg):
            self.discard(msg,
                         "Cannot get primary status for a request for a later "
                         "view {}. Request is {}".format(self.viewNo, msg),
                         logger.error)
        else:
            return self.isPrimary if self.isMsgForCurrentView(msg) \
                else self.isPrimaryInView(msg.viewNo)

    def isMsgFromPrimary(self, msg, sender: str) -> bool:
        """
        Return whether this message was from primary replica
        :param msg:
        :param sender:
        :return:
        """
        if self.isMsgForLaterView(msg):
            logger.error("{} cannot get primary for a request for a later "
                         "view. Request is {}".format(self, msg))
        else:
            return self.primaryName == sender if self.isMsgForCurrentView(
                msg) else self.primaryNames[msg.viewNo] == sender

    def _preProcessReqDigest(self, rd: ReqDigest) -> None:
        """
        Process request digest if this replica is not a primary, otherwise stash
        the message into the inBox.

        :param rd: the client Request Digest
        """
        if self.isPrimary is not None:
            self.processReqDigest(rd)
        else:
            logger.debug("{} stashing request digest {} since it does not know "
                         "its primary status".
                         format(self, (rd.identifier, rd.reqId)))
            self._stashInBox(rd)

    def serviceQueues(self, limit=None):
        """
        Process `limit` number of messages in the inBox.

        :param limit: the maximum number of messages to process
        :return: the number of messages successfully processed
        """
        # TODO should handle SuspiciousNode here
        r = self.inBoxRouter.handleAllSync(self.inBox, limit)
        r += self._serviceActions()
        return r
        # Messages that cannot be processed right now need to be added back to
        # the queue; they might be processable later

    def processPostElectionMsgs(self):
        """
        Process messages waiting for the election of a primary replica to
        complete.
        """
        while self.postElectionMsgs:
            msg = self.postElectionMsgs.popleft()
            logger.debug("{} processing pended msg {}".format(self, msg))
            self.dispatchThreePhaseMsg(*msg)

    def process3PhaseReqsQueue(self):
        """
        Process the 3 phase requests from the queue whose view number is equal
        to the current view number of this replica.
        """
        unprocessed = deque()
        while self.threePhaseMsgsForLaterView:
            request, sender = self.threePhaseMsgsForLaterView.popleft()
            logger.debug("{} processing pended 3 phase request: {}"
                         .format(self, request))
            # If the request is for a later view, don't try to process it; add
            # it back to the queue.
            if self.isMsgForLaterView(request):
                unprocessed.append((request, sender))
            else:
                self.processThreePhaseMsg(request, sender)
        self.threePhaseMsgsForLaterView = unprocessed

    @property
    def quorum(self) -> int:
        r"""
        Return the quorum of this RBFT system. Equal to :math:`2f + 1`.
        Return None if `f` is not yet determined.
        """
        return self.node.quorum

    def dispatchThreePhaseMsg(self, msg: ThreePhaseMsg, sender: str) -> Any:
        """
        Create a three phase request to be handled by the threePhaseRouter.

        :param msg: the ThreePhaseMsg to dispatch
        :param sender: the name of the node that sent this request
        """
        senderRep = self.generateName(sender, self.instId)
        if self.isPpSeqNoAcceptable(msg.ppSeqNo):
            try:
                self.threePhaseRouter.handleSync((msg, senderRep))
            except SuspiciousNode as ex:
                self.node.reportSuspiciousNodeEx(ex)
        else:
            logger.debug("{} stashing 3 phase message {} since ppSeqNo {} is "
                         "not between {} and {}".
                         format(self, msg, msg.ppSeqNo, self.h, self.H))
            self.stashingWhileOutsideWaterMarks.append((msg, sender))

    def processReqDigest(self, rd: ReqDigest):
        """
        Process a request digest. Works only if this replica has decided its
        primary status.

        :param rd: the client request digest to process
        """
        self.stats.inc(TPCStat.ReqDigestRcvd)
        if self.isPrimary is False:
            self.dequeuePrePrepare(rd.identifier, rd.reqId)
        else:
            self.doPrePrepare(rd)

    def processThreePhaseMsg(self, msg: ThreePhaseMsg, sender: str):
        """
        Process a 3-phase (pre-prepare, prepare and commit) request.
        Dispatch the request only if primary has already been decided, otherwise
        stash it.

        :param msg: the Three Phase message, one of PRE-PREPARE, PREPARE,
            COMMIT
        :param sender: name of the node that sent this message
        """
        # Can only proceed further if it knows whether it's primary or not
        if self.isMsgForLaterView(msg):
            self.threePhaseMsgsForLaterView.append((msg, sender))
            logger.debug("{} pended received 3 phase request for a later view: "
                         "{}".format(self, msg))
        else:
            if self.isPrimary is None:
                self.postElectionMsgs.append((msg, sender))
                logger.debug("Replica {} pended request {} from {}".
                             format(self, msg, sender))
            else:
                self.dispatchThreePhaseMsg(msg, sender)

    def processPrePrepare(self, pp: PrePrepare, sender: str):
        """
        Validate and process the PRE-PREPARE specified.
        If validation is successful, create a PREPARE and broadcast it.

        :param pp: a prePrepareRequest
        :param sender: name of the node that sent this message
        """
        key = (pp.viewNo, pp.ppSeqNo)
        logger.debug("{} Receiving PRE-PREPARE{} at {} from {}".
                     format(self, key, time.perf_counter(), sender))
        if self.canProcessPrePrepare(pp, sender):
            if not self.node.isParticipating:
                self.stashingWhileCatchingUp.add(key)
            self.addToPrePrepares(pp)
            logger.info("{} processed incoming PRE-PREPARE{}".
                        format(self, key))

    def tryPrepare(self, pp: PrePrepare):
        """
        Try to send the Prepare message if the PrePrepare message is ready to
        be passed into the Prepare phase.
        """
        if self.canSendPrepare(pp):
            self.doPrepare(pp)
        else:
            logger.debug("{} cannot send PREPARE".format(self))

    def processPrepare(self, prepare: Prepare, sender: str) -> None:
        """
        Validate and process the PREPARE specified.
        If validation is successful, create a COMMIT and broadcast it.

        :param prepare: a PREPARE msg
        :param sender: name of the node that sent the PREPARE
        """
        # TODO move this try/except up higher
        logger.debug("{} received PREPARE{} from {}".
                     format(self, (prepare.viewNo, prepare.ppSeqNo), sender))
        try:
            if self.isValidPrepare(prepare, sender):
                self.addToPrepares(prepare, sender)
                self.stats.inc(TPCStat.PrepareRcvd)
                logger.debug("{} processed incoming PREPARE {}".
                             format(self, (prepare.viewNo, prepare.ppSeqNo)))
            else:
                # TODO let's have isValidPrepare throw an exception that gets
                # handled and possibly logged higher
                logger.warning("{} cannot process incoming PREPARE".
                               format(self))
        except SuspiciousNode as ex:
            self.node.reportSuspiciousNodeEx(ex)

    def processCommit(self, commit: Commit, sender: str) -> None:
        """
        Validate and process the COMMIT specified.
        If validation is successful, return the message to the node.

        :param commit: an incoming COMMIT message
        :param sender: name of the node that sent the COMMIT
        """
        logger.debug("{} received COMMIT {} from {}".
                     format(self, commit, sender))
        if self.isValidCommit(commit, sender):
            self.stats.inc(TPCStat.CommitRcvd)
            self.addToCommits(commit, sender)
            logger.debug("{} processed incoming COMMIT{}".
                         format(self, (commit.viewNo, commit.ppSeqNo)))

    def tryCommit(self, prepare: Prepare):
        """
        Try to commit if the Prepare message is ready to be passed into the
        commit phase.
        """
        if self.canCommit(prepare):
            self.doCommit(prepare)
        else:
            logger.debug("{} not yet able to send COMMIT".format(self))

    def tryOrder(self, commit: Commit):
        """
        Try to order if the Commit message is ready to be ordered.
        """
        canOrder, reason = self.canOrder(commit)
        if canOrder:
            logger.debug("{} returning request to node".format(self))
            self.tryOrdering(commit)
        else:
            logger.trace("{} cannot return request to node: {}".
                         format(self, reason))

    def doPrePrepare(self, reqDigest: ReqDigest) -> None:
        """
        Broadcast a PRE-PREPARE to all the replicas.

        :param reqDigest: a tuple with elements identifier, reqId, and digest
        """
        if not self.node.isParticipating:
            logger.error("Non participating node is attempting PRE-PREPARE. "
                         "This should not happen.")
            return

        if self.lastPrePrepareSeqNo == self.H:
            logger.debug("{} stashing PRE-PREPARE {} since outside greater "
                         "than high water mark {}".
                         format(self, (self.viewNo, self.lastPrePrepareSeqNo+1),
                                self.H))
            self.stashingWhileOutsideWaterMarks.append(reqDigest)
            return
        self.lastPrePrepareSeqNo += 1
        tm = time.time()*1000
        logger.debug("{} Sending PRE-PREPARE {} at {}".
                     format(self, (self.viewNo, self.lastPrePrepareSeqNo),
                            time.perf_counter()))
        prePrepareReq = PrePrepare(self.instId,
                                   self.viewNo,
                                   self.lastPrePrepareSeqNo,
                                   *reqDigest,
                                   tm)
        self.sentPrePrepares[self.viewNo, self.lastPrePrepareSeqNo] = (reqDigest.key,
                                                                       tm)
        self.send(prePrepareReq, TPCStat.PrePrepareSent)

    def doPrepare(self, pp: PrePrepare):
        logger.debug("{} Sending PREPARE {} at {}".
                     format(self, (pp.viewNo, pp.ppSeqNo), time.perf_counter()))
        prepare = Prepare(self.instId,
                          pp.viewNo,
                          pp.ppSeqNo,
                          pp.digest,
                          pp.ppTime)
        self.send(prepare, TPCStat.PrepareSent)
        self.addToPrepares(prepare, self.name)

    def doCommit(self, p: Prepare):
        """
        Create a commit message from the given Prepare message and trigger the
        commit phase
        :param p: the prepare message
        """
        logger.debug("{} Sending COMMIT{} at {}".
                     format(self, (p.viewNo, p.ppSeqNo), time.perf_counter()))
        commit = Commit(self.instId,
                        p.viewNo,
                        p.ppSeqNo,
                        p.digest,
                        p.ppTime)
        self.send(commit, TPCStat.CommitSent)
        self.addToCommits(commit, self.name)

    def canProcessPrePrepare(self, pp: PrePrepare, sender: str) -> bool:
        """
        Decide whether this replica is eligible to process a PRE-PREPARE,
        based on the following criteria:

        - this replica is a non-primary replica
        - the request isn't in its list of received PRE-PREPAREs
        - the request is waiting for a PRE-PREPARE and the digest value matches

        :param pp: a PRE-PREPARE msg to process
        :param sender: the name of the node that sent the PRE-PREPARE msg
        :return: True if processing is allowed, False otherwise
        """
        # TODO: Check whether it is rejecting PRE-PREPARE from previous view
        # PRE-PREPARE should not be sent from non primary
        if not self.isMsgFromPrimary(pp, sender):
            raise SuspiciousNode(sender, Suspicions.PPR_FRM_NON_PRIMARY, pp)

        # A PRE-PREPARE is being sent to primary
        if self.isPrimaryForMsg(pp) is True:
            raise SuspiciousNode(sender, Suspicions.PPR_TO_PRIMARY, pp)

        # A PRE-PREPARE is sent that has already been received
        if (pp.viewNo, pp.ppSeqNo) in self.prePrepares:
            raise SuspiciousNode(sender, Suspicions.DUPLICATE_PPR_SENT, pp)

        key = (pp.identifier, pp.reqId)
        if not self.requests.isFinalised(key):
            self.enqueuePrePrepare(pp, sender)
            return False

        # A PRE-PREPARE is sent that does not match request digest
        if self.requests.digest(key) != pp.digest:
            raise SuspiciousNode(sender, Suspicions.PPR_DIGEST_WRONG, pp)

        return True

    def addToPrePrepares(self, pp: PrePrepare) -> None:
        """
        Add the specified PRE-PREPARE to this replica's list of received
        PRE-PREPAREs.

        :param pp: the PRE-PREPARE to add to the list
        """
        key = (pp.viewNo, pp.ppSeqNo)
        self.prePrepares[key] = \
            ((pp.identifier, pp.reqId), pp.ppTime)
        self.dequeuePrepares(*key)
        self.dequeueCommits(*key)
        self.stats.inc(TPCStat.PrePrepareRcvd)
        self.tryPrepare(pp)

    def hasPrepared(self, request) -> bool:
        return self.prepares.hasPrepareFrom(request, self.name)

    def canSendPrepare(self, request) -> bool:
        """
        Return whether the request identified by (identifier, reqId) can
        proceed to the Prepare step.

        :param request: any object with viewNo, ppSeqNo, identifier and reqId
            attributes
        """
        return self.shouldParticipate(request.viewNo, request.ppSeqNo) \
            and not self.hasPrepared(request) \
            and self.requests.isFinalised((request.identifier,
                                           request.reqId))

    def isValidPrepare(self, prepare: Prepare, sender: str) -> bool:
        """
        Return whether the PREPARE specified is valid.

        :param prepare: the PREPARE to validate
        :param sender: the name of the node that sent the PREPARE
        :return: True if PREPARE is valid, False otherwise
        """
        key = (prepare.viewNo, prepare.ppSeqNo)
        primaryStatus = self.isPrimaryForMsg(prepare)

        ppReqs = self.sentPrePrepares if primaryStatus else self.prePrepares

        # If a non primary replica and receiving a PREPARE request before a
        # PRE-PREPARE request, then proceed

        # PREPARE should not be sent from primary
        if self.isMsgFromPrimary(prepare, sender):
            raise SuspiciousNode(sender, Suspicions.PR_FRM_PRIMARY, prepare)

        # If non primary replica
        if primaryStatus is False:
            if self.prepares.hasPrepareFrom(prepare, sender):
                raise SuspiciousNode(sender, Suspicions.DUPLICATE_PR_SENT, prepare)
            # If PRE-PREPARE not received for the PREPARE, might be slow network
            if key not in ppReqs:
                self.enqueuePrepare(prepare, sender)
                return False
            elif prepare.digest != self.requests.digest(ppReqs[key][0]):
                raise SuspiciousNode(sender, Suspicions.PR_DIGEST_WRONG, prepare)
            elif prepare.ppTime != ppReqs[key][1]:
                raise SuspiciousNode(sender, Suspicions.PR_TIME_WRONG,
                                     prepare)
            else:
                return True
        # If primary replica
        else:
            if self.prepares.hasPrepareFrom(prepare, sender):
                raise SuspiciousNode(sender, Suspicions.DUPLICATE_PR_SENT, prepare)
            # If PRE-PREPARE was not sent for this PREPARE, certainly
            # malicious behavior
            elif key not in ppReqs:
                raise SuspiciousNode(sender, Suspicions.UNKNOWN_PR_SENT, prepare)
            elif prepare.digest != self.requests.digest(ppReqs[key][0]):
                raise SuspiciousNode(sender, Suspicions.PR_DIGEST_WRONG, prepare)
            elif prepare.ppTime != ppReqs[key][1]:
                raise SuspiciousNode(sender, Suspicions.PR_TIME_WRONG,
                                     prepare)
            else:
                return True

    def addToPrepares(self, prepare: Prepare, sender: str):
        self.prepares.addVote(prepare, sender)
        self.tryCommit(prepare)

    def hasCommitted(self, request) -> bool:
        return self.commits.hasCommitFrom(ThreePhaseKey(
            request.viewNo, request.ppSeqNo), self.name)

    def canCommit(self, prepare: Prepare) -> bool:
        """
        Return whether the specified PREPARE can proceed to the Commit
        step.

        Decision criteria:

        - If this replica has received exactly 2f PREPARE requests, then send a
            COMMIT.
        - If it has received fewer than 2f PREPARE requests, there is probably no
            consensus on the request yet; don't commit.
        - If it has received more than 2f, a COMMIT has already been sent; don't
            commit.

        :param prepare: the PREPARE
        """
        return self.shouldParticipate(prepare.viewNo, prepare.ppSeqNo) and \
            self.prepares.hasQuorum(prepare, self.f) and \
            not self.hasCommitted(prepare)
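        # Worked example (illustrative, not from the source): with f = 1 this
        # requires self.prepares.hasQuorum(prepare, 1) to have seen at least
        # 2f = 2 PREPAREs for the request before this replica sends a COMMIT.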

    def isValidCommit(self, commit: Commit, sender: str) -> bool:
        """
        Return whether the COMMIT specified is valid.

        :param commit: the COMMIT to validate
        :return: True if `request` is valid, False otherwise
        """
        primaryStatus = self.isPrimaryForMsg(commit)
        ppReqs = self.sentPrePrepares if primaryStatus else self.prePrepares
        key = (commit.viewNo, commit.ppSeqNo)
        if key not in ppReqs:
            self.enqueueCommit(commit, sender)
            return False

        if (key not in self.prepares and
                key not in self.preparesWaitingForPrePrepare):
            logger.debug("{} rejecting COMMIT{} due to lack of prepares".
                         format(self, key))
            # raise SuspiciousNode(sender, Suspicions.UNKNOWN_CM_SENT, commit)
            return False
        elif self.commits.hasCommitFrom(commit, sender):
            raise SuspiciousNode(sender, Suspicions.DUPLICATE_CM_SENT, commit)
        elif commit.digest != self.getDigestFor3PhaseKey(ThreePhaseKey(*key)):
            raise SuspiciousNode(sender, Suspicions.CM_DIGEST_WRONG, commit)
        elif key in ppReqs and commit.ppTime != ppReqs[key][1]:
            raise SuspiciousNode(sender, Suspicions.CM_TIME_WRONG,
                                 commit)
        else:
            return True

    def addToCommits(self, commit: Commit, sender: str):
        """
        Add the specified COMMIT to this replica's list of received
        commit requests.

        :param commit: the COMMIT to add to the list
        :param sender: the name of the node that sent the COMMIT
        """
        self.commits.addVote(commit, sender)
        self.tryOrder(commit)

    def hasOrdered(self, viewNo, ppSeqNo) -> bool:
        return (viewNo, ppSeqNo) in self.ordered

    def canOrder(self, commit: Commit) -> Tuple[bool, Optional[str]]:
        """
        Return whether the specified commitRequest can be returned to the node.

        Decision criteria:

        - If exactly 2f+1 COMMIT requests have been received, then return the
            request to the node.
        - If fewer than 2f+1 COMMIT requests have been received, there is probably
            no consensus on the request yet; don't return it to the node.
        - If more than 2f+1 have been received, the request has already been
            returned to the node; don't return it again.

        :param commit: the COMMIT
        """
        if not self.commits.hasQuorum(commit, self.f):
            return False, "no quorum: {} commits where f is {}".\
                          format(commit, self.f)

        if self.hasOrdered(commit.viewNo, commit.ppSeqNo):
            return False, "already ordered"

        if not self.isNextInOrdering(commit):
            viewNo, ppSeqNo = commit.viewNo, commit.ppSeqNo
            if viewNo not in self.stashedCommitsForOrdering:
                self.stashedCommitsForOrdering[viewNo] = {}
            self.stashedCommitsForOrdering[viewNo][ppSeqNo] = commit
            # self._schedule(self.orderStashedCommits, 2)
            self.startRepeating(self.orderStashedCommits, 2)
            return False, "stashing {} since out of order".\
                format(commit)

        return True, None

    def isNextInOrdering(self, commit: Commit):
        viewNo, ppSeqNo = commit.viewNo, commit.ppSeqNo
        if self.ordered and self.ordered[-1] == (viewNo, ppSeqNo-1):
            return True
        for (v, p) in self.commits:
            if v < viewNo:
                # Have commits from previous view that are unordered.
                # TODO: Question: would commits be always ordered, what if
                # some are never ordered and its fine, go to PBFT.
                return False
            if v == viewNo and p < ppSeqNo and (v, p) not in self.ordered:
                # If unordered commits are found with lower ppSeqNo then this
                # cannot be ordered.
                return False

        # TODO: Revisit PBFT paper, how to make sure that last request of the
        # last view has been ordered? Need change in `VIEW CHANGE` mechanism.
        # Somehow view change needs to communicate what the last request was.
        # Also what if some COMMITs were completely missed in the same view
        return True

    def orderStashedCommits(self):
        # TODO: What if the first few commits were out of order and stashed?
        # `self.ordered` would be empty
        if self.ordered:
            lastOrdered = self.ordered[-1]
            vToRemove = set()
            for v in self.stashedCommitsForOrdering:
                if v < lastOrdered[0] and self.stashedCommitsForOrdering[v]:
                    raise RuntimeError("{} found commits from previous view {}"
                                       " that were not ordered but last ordered"
                                       " is {}".format(self, v, lastOrdered))
                pToRemove = set()
                for p, commit in self.stashedCommitsForOrdering[v].items():
                    if (v == lastOrdered[0] and lastOrdered == (v, p - 1)) or \
                            (v > lastOrdered[0] and
                                self.isLowestCommitInView(commit)):
                        logger.debug("{} ordering stashed commit {}".
                                     format(self, commit))
                        if self.tryOrdering(commit):
                            lastOrdered = (v, p)
                            pToRemove.add(p)

                for p in pToRemove:
                    del self.stashedCommitsForOrdering[v][p]
                if not self.stashedCommitsForOrdering[v]:
                    vToRemove.add(v)

            for v in vToRemove:
                del self.stashedCommitsForOrdering[v]

            # if self.stashedCommitsForOrdering:
            #     self._schedule(self.orderStashedCommits, 2)
            if not self.stashedCommitsForOrdering:
                self.stopRepeating(self.orderStashedCommits)

    def isLowestCommitInView(self, commit):
        # TODO: Assumption: This assumes that at least one commit that was sent
        #  for any request by any node has been received in the view of this
        # commit
        ppSeqNos = []
        for v, p in self.commits:
            if v == commit.viewNo:
                ppSeqNos.append(p)
        return min(ppSeqNos) == commit.ppSeqNo if ppSeqNos else True

    def tryOrdering(self, commit: Commit) -> bool:
        """
        Attempt to send an ORDERED request for the specified COMMIT to the
        node.

        :param commit: the COMMIT message
        """
        key = (commit.viewNo, commit.ppSeqNo)
        logger.debug("{} trying to order COMMIT{}".format(self, key))
        reqKey = self.getReqKeyFrom3PhaseKey(key)   # type: Tuple
        digest = self.getDigestFor3PhaseKey(key)
        if not digest:
            logger.error("{} did not find digest for {}, request key {}".
                         format(self, key, reqKey))
            return False
        self.doOrder(*key, *reqKey, digest, commit.ppTime)
        return True

    def doOrder(self, viewNo, ppSeqNo, identifier, reqId, digest, ppTime):
        key = (viewNo, ppSeqNo)
        self.addToOrdered(*key)
        ordered = Ordered(self.instId,
                          viewNo,
                          identifier,
                          reqId,
                          ppTime)
        # TODO: Should not order or add to checkpoint while syncing
        # 3 phase state.
        self.send(ordered, TPCStat.OrderSent)
        if key in self.stashingWhileCatchingUp:
            self.stashingWhileCatchingUp.remove(key)
        logger.debug("{} ordered request {}".format(self, (viewNo, ppSeqNo)))
        self.addToCheckpoint(ppSeqNo, digest)

    def processCheckpoint(self, msg: Checkpoint, sender: str):
        if self.checkpoints:
            seqNo = msg.seqNo
            _, firstChk = self.firstCheckPoint
            if firstChk.isStable:
                if firstChk.seqNo == seqNo:
                    self.discard(msg, reason="Checkpoint already stable",
                                 logMethod=logger.debug)
                    return
                if firstChk.seqNo > seqNo:
                    self.discard(msg, reason="Higher stable checkpoint present",
                                 logMethod=logger.debug)
                    return
            for state in self.checkpoints.values():
                if state.seqNo == seqNo:
                    if state.digest == msg.digest:
                        state.receivedDigests[sender] = msg.digest
                        break
                    else:
                        logger.error("{} received an incorrect digest {} for "
                                     "checkpoint {} from {}".format(self,
                                                                    msg.digest,
                                                                    seqNo,
                                                                    sender))
                        return
            if len(state.receivedDigests) == 2*self.f:
                self.markCheckPointStable(msg.seqNo)
        else:
            self.discard(msg, reason="No checkpoints present to tally",
                         logMethod=logger.warn)

    def _newCheckpointState(self, ppSeqNo, digest) -> CheckpointState:
        s, e = ppSeqNo, ppSeqNo + self.config.CHK_FREQ - 1
        logger.debug("{} adding new checkpoint state for {}".
                     format(self, (s, e)))
        state = CheckpointState(ppSeqNo, [digest, ], None, {}, False)
        self.checkpoints[s, e] = state
        return state

    def addToCheckpoint(self, ppSeqNo, digest):
        for (s, e) in self.checkpoints.keys():
            if s <= ppSeqNo <= e:
                state = self.checkpoints[s, e]  # type: CheckpointState
                state.digests.append(digest)
                state = updateNamedTuple(state, seqNo=ppSeqNo)
                self.checkpoints[s, e] = state
                break
        else:
            state = self._newCheckpointState(ppSeqNo, digest)
            s, e = ppSeqNo, ppSeqNo + self.config.CHK_FREQ - 1

        if len(state.digests) == self.config.CHK_FREQ:
            state = updateNamedTuple(state, digest=serialize(state.digests),
                                     digests=[])
            self.checkpoints[s, e] = state
            self.send(Checkpoint(self.instId, self.viewNo, ppSeqNo,
                                 state.digest))
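        # Illustrative example (assuming config.CHK_FREQ = 100): once 100 ordered
        # requests fall into the current checkpoint window, the collected digests
        # are serialized and a Checkpoint message is sent via self.send().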

    def markCheckPointStable(self, seqNo):
        previousCheckpoints = []
        for (s, e), state in self.checkpoints.items():
            if e == seqNo:
                state = updateNamedTuple(state, isStable=True)
                self.checkpoints[s, e] = state
                break
            else:
                previousCheckpoints.append((s, e))
        else:
            logger.error("{} could not find {} in checkpoints".
                         format(self, seqNo))
            return
        self.h = seqNo
        for k in previousCheckpoints:
            logger.debug("{} removing previous checkpoint {}".format(self, k))
            self.checkpoints.pop(k)
        self.gc(seqNo)
        logger.debug("{} marked stable checkpoint {}".format(self, (s, e)))
        self.processStashedMsgsForNewWaterMarks()

    def gc(self, tillSeqNo):
        logger.debug("{} cleaning up till {}".format(self, tillSeqNo))
        tpcKeys = set()
        reqKeys = set()
        for (v, p), (reqKey, _) in self.sentPrePrepares.items():
            if p <= tillSeqNo:
                tpcKeys.add((v, p))
                reqKeys.add(reqKey)
        for (v, p), (reqKey, _) in self.prePrepares.items():
            if p <= tillSeqNo:
                tpcKeys.add((v, p))
                reqKeys.add(reqKey)

        logger.debug("{} found {} 3 phase keys to clean".
                     format(self, len(tpcKeys)))
        logger.debug("{} found {} request keys to clean".
                     format(self, len(reqKeys)))

        for k in tpcKeys:
            self.sentPrePrepares.pop(k, None)
            self.prePrepares.pop(k, None)
            self.prepares.pop(k, None)
            self.commits.pop(k, None)
            if k in self.ordered:
                self.ordered.remove(k)

        for k in reqKeys:
            self.requests.pop(k, None)

    def processStashedMsgsForNewWaterMarks(self):
        while self.stashingWhileOutsideWaterMarks:
            item = self.stashingWhileOutsideWaterMarks.pop()
            logger.debug("{} processing stashed item {} after new stable "
                         "checkpoint".format(self, item))

            if isinstance(item, ReqDigest):
                self.doPrePrepare(item)
            elif isinstance(item, tuple) and len(item) == 2:
                self.dispatchThreePhaseMsg(*item)
            else:
                logger.error("{} cannot process {} "
                             "from stashingWhileOutsideWaterMarks".
                             format(self, item))

    @property
    def firstCheckPoint(self) -> Tuple[Tuple[int, int], CheckpointState]:
        if not self.checkpoints:
            return None
        else:
            return self.checkpoints.peekitem(0)

    @property
    def lastCheckPoint(self) -> Tuple[Tuple[int, int], CheckpointState]:
        if not self.checkpoints:
            return None
        else:
            return self.checkpoints.peekitem(-1)

    def isPpSeqNoAcceptable(self, ppSeqNo: int):
        return self.h < ppSeqNo <= self.H
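        # Example (illustrative): with watermarks h = 0 and H = 300, ppSeqNo
        # values 1..300 are acceptable; anything outside that range is stashed by
        # dispatchThreePhaseMsg until the watermarks advance.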

    def addToOrdered(self, viewNo: int, ppSeqNo: int):
        self.ordered.add((viewNo, ppSeqNo))

    def enqueuePrePrepare(self, request: PrePrepare, sender: str):
        logger.debug("Queueing pre-prepares due to unavailability of finalised "
                     "Request. Request {} from {}".format(request, sender))
        key = (request.identifier, request.reqId)
        if key not in self.prePreparesPendingReqDigest:
            self.prePreparesPendingReqDigest[key] = []
        self.prePreparesPendingReqDigest[key].append((request, sender))

    def dequeuePrePrepare(self, identifier: int, reqId: int):
        key = (identifier, reqId)
        if key in self.prePreparesPendingReqDigest:
            pps = self.prePreparesPendingReqDigest[key]
            for (pp, sender) in pps:
                logger.debug("{} popping stashed PRE-PREPARE{}".
                             format(self, key))
                if pp.digest == self.requests.digest(key):
                    self.prePreparesPendingReqDigest.pop(key)
                    self.processPrePrepare(pp, sender)
                    logger.debug(
                        "{} processed {} PRE-PREPAREs waiting for finalised "
                        "request for identifier {} and reqId {}".
                        format(self, pp, identifier, reqId))
                    break

    def enqueuePrepare(self, request: Prepare, sender: str):
        logger.debug("Queueing prepares due to unavailability of PRE-PREPARE. "
                     "Request {} from {}".format(request, sender))
        key = (request.viewNo, request.ppSeqNo)
        if key not in self.preparesWaitingForPrePrepare:
            self.preparesWaitingForPrePrepare[key] = deque()
        self.preparesWaitingForPrePrepare[key].append((request, sender))

    def dequeuePrepares(self, viewNo: int, ppSeqNo: int):
        key = (viewNo, ppSeqNo)
        if key in self.preparesWaitingForPrePrepare:
            i = 0
            # Keys of pending prepares that will be processed below
            while self.preparesWaitingForPrePrepare[key]:
                prepare, sender = self.preparesWaitingForPrePrepare[
                    key].popleft()
                logger.debug("{} popping stashed PREPARE{}".format(self, key))
                self.processPrepare(prepare, sender)
                i += 1
            self.preparesWaitingForPrePrepare.pop(key)
            logger.debug("{} processed {} PREPAREs waiting for PRE-PREPARE for"
                         " view no {} and seq no {}".
                         format(self, i, viewNo, ppSeqNo))

    def enqueueCommit(self, request: Commit, sender: str):
        logger.debug("Queueing commit due to unavailability of PREPARE. "
                     "Request {} from {}".format(request, sender))
        key = (request.viewNo, request.ppSeqNo)
        if key not in self.commitsWaitingForPrepare:
            self.commitsWaitingForPrepare[key] = deque()
        self.commitsWaitingForPrepare[key].append((request, sender))

    def dequeueCommits(self, viewNo: int, ppSeqNo: int):
        key = (viewNo, ppSeqNo)
        if key in self.commitsWaitingForPrepare:
            i = 0
            # Keys of pending commits that will be processed below
            while self.commitsWaitingForPrepare[key]:
                commit, sender = self.commitsWaitingForPrepare[
                    key].popleft()
                logger.debug("{} popping stashed COMMIT{}".format(self, key))
                self.processCommit(commit, sender)
                i += 1
            self.commitsWaitingForPrepare.pop(key)
            logger.debug("{} processed {} COMMITs waiting for PREPARE for"
                         " view no {} and seq no {}".
                         format(self, i, viewNo, ppSeqNo))

    def getDigestFor3PhaseKey(self, key: ThreePhaseKey) -> Optional[str]:
        reqKey = self.getReqKeyFrom3PhaseKey(key)
        digest = self.requests.digest(reqKey)
        if not digest:
            logger.debug("{} could not find digest in sent or received "
                         "PRE-PREPAREs or PREPAREs for 3 phase key {} and req "
                         "key {}".format(self, key, reqKey))
            return None
        else:
            return digest

    def getReqKeyFrom3PhaseKey(self, key: ThreePhaseKey):
        reqKey = None
        if key in self.sentPrePrepares:
            reqKey = self.sentPrePrepares[key][0]
        elif key in self.prePrepares:
            reqKey = self.prePrepares[key][0]
        elif key in self.prepares:
            reqKey = self.prepares[key][0]
        else:
            logger.debug("Could not find request key for 3 phase key {}".
                         format(key))
        return reqKey

    @property
    def threePhaseState(self):
        # TODO: This method is incomplete
        # Gets the current stable and unstable checkpoints and creates digest
        # of unstable checkpoints
        state = []
        if self.checkpoints:
            pass
        return ThreePCState(self.instId, state)

    def process3PhaseState(self, msg: ThreePCState, sender: str):
        # TODO: This is not complete
        pass

    def send(self, msg, stat=None) -> None:
        """
        Send a message to the node on which this replica resides.

        :param msg: the message to send
        """
        logger.display("{} sending {}".format(self, msg.__class__.__name__),
                       extra={"cli": True})
        logger.trace("{} sending {}".format(self, msg))
        if stat:
            self.stats.inc(stat)
        self.outBox.append(msg)
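
# A minimal, hypothetical aside (not part of the Replica class above): the
# commit and order decisions rely on simple quorum arithmetic over f, the number
# of faulty nodes tolerated, as described in the canCommit and canOrder
# docstrings. The standalone helpers below only illustrate those thresholds.
def has_prepare_quorum(prepare_votes: int, f: int) -> bool:
    # canCommit-style threshold: at least 2f PREPAREs must have been received
    return prepare_votes >= 2 * f


def has_commit_quorum(commit_votes: int, f: int) -> bool:
    # canOrder-style threshold: at least 2f + 1 COMMITs must have been received
    return commit_votes >= 2 * f + 1


if __name__ == "__main__":
    f = 1                             # tolerate one faulty node
    print(has_prepare_quorum(2, f))   # True: 2f = 2 PREPAREs suffice
    print(has_commit_quorum(2, f))    # False: ordering needs 2f + 1 = 3 COMMITs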
Example #33
0
def audit_project(config):
    """
    Audit a project according to the configured values provided
    """

    logging.info("START %s %s" % (config.PROJECT, config.START))

    counts = {'nextcloudNodeCount': 0, 'filesystemNodeCount': 0, 'idaNodeCount': 0, 'metaxNodeCount': 0}
    nodes = SortedDict({})

    # Populate auditing data objects for all nodes in scope according to the configured values provided

    add_nextcloud_nodes(nodes, counts, config)
    add_filesystem_nodes(nodes, counts, config)
    add_frozen_files(nodes, counts, config)
    add_metax_files(nodes, counts, config)

    # Iterate over all nodes, logging and reporting all errors

    invalidNodes = SortedDict({})
    invalidNodeCount = 0

    if config.DEBUG == 'true':
        # Only count and report files in debug output, because we track how many files are in the project, not nodes/folders/etc.
        fileCount = 0

    for pathname, node in nodes.items():

        errors = SortedDict({})

        # Determine whether the node is in the frozen area based on the pathname
    
        is_frozen_area_pathname = False

        if pathname[:1] == 'f':
            is_frozen_area_pathname = True

        # Determine where the node exists

        filesystem = node.get('filesystem', False)
        nextcloud = node.get('nextcloud', False)
        ida = node.get('ida', False)
        metax = node.get('metax', False)

        if config.DEBUG == 'true':
            if (nextcloud and nextcloud['type'] == 'file') or (filesystem and filesystem['type'] == 'file') or (ida and ida['type'] == 'file') or (metax and metax['type'] == 'file'):
                fileCount = fileCount + 1
            sys.stderr.write("%s: auditing: %d %s\n" % (config.PROJECT, fileCount, pathname))

        # Check that node exists in both filesystem and Nextcloud, and with same type

        if filesystem and not nextcloud:
            errors['Node does not exist in Nextcloud'] = True

        if nextcloud and not filesystem:
            errors['Node does not exist in filesystem'] = True

        if filesystem and nextcloud and filesystem['type'] != nextcloud['type']:
            errors['Node type different for filesystem and Nextcloud'] = True

        # If filesystem and nextcloud agree the node is a file, apply further checks...

        if filesystem and filesystem['type'] == 'file' and nextcloud and nextcloud['type'] == 'file':

            if filesystem and nextcloud and filesystem['size'] != nextcloud['size']:
                errors['Node size different for filesystem and Nextcloud'] = True

            if config.IGNORE_TIMESTAMPS != 'true':
                if filesystem and nextcloud and filesystem['modified'] != nextcloud['modified']:
                    errors['Node modification timestamp different for filesystem and Nextcloud'] = True

        # If pathname is in the frozen area, and is known to either Nextcloud or the filesystem
        # as a file; check that the file is registered both as frozen by the IDA app and is published
        # to metax, is also known as a file by both the filesystem and Nextcloud, is replicated properly,
        # and that all relevant file details agree.

        if is_frozen_area_pathname and (ida or metax or (filesystem and filesystem['type'] == 'file') or (nextcloud and nextcloud['type'] == 'file')):

            # check if IDA details exist for frozen file
            if not ida: 
                errors['Node does not exist in IDA'] = True

            # check if metax details exist for frozen file
            if not metax:
                errors['Node does not exist in Metax'] = True

            # check if details exist only in IDA or metax for frozen file
            if not filesystem and not nextcloud:
                errors['Node does not exist in filesystem'] = True
                errors['Node does not exist in Nextcloud'] = True

            # check node type agreement between IDA and filesystem
            if ida and filesystem and filesystem['type'] == 'folder':
                errors['Node type different for filesystem and IDA'] = True

            # check node type agreement between IDA and Nextcloud
            if ida and nextcloud and nextcloud['type'] == 'folder':
                errors['Node type different for Nextcloud and IDA'] = True

            # check node type agreement between metax and filesystem
            if metax and filesystem and filesystem['type'] == 'folder':
                errors['Node type different for filesystem and Metax'] = True

            # check node type agreement between metax and nextcloud
            if metax and nextcloud and nextcloud['type'] == 'folder':
                errors['Node type different for Nextcloud and Metax'] = True

            # if known in both IDA and filesystem, check if file details agree
            if ida and filesystem and filesystem['type'] == 'file':

                if ida['size'] != filesystem['size']:
                    errors['Node size different for filesystem and IDA'] = True

                if config.IGNORE_TIMESTAMPS != 'true':
                    if ida['modified'] != filesystem['modified']:
                        errors['Node modification timestamp different for filesystem and IDA'] = True

            # if known in both IDA and nextcloud, check if file details agree
            if ida and nextcloud and nextcloud['type'] == 'file':

                if ida['size'] != nextcloud['size']:
                    errors['Node size different for Nextcloud and IDA'] = True

                if config.IGNORE_TIMESTAMPS != 'true':
                    if ida['modified'] != nextcloud['modified']:
                        errors['Node modification timestamp different for Nextcloud and IDA'] = True

            # if known in both metax and filesystem, check if file details agree
            if metax and filesystem and filesystem['type'] == 'file':

                if metax['size'] != filesystem['size']:
                    errors['Node size different for filesystem and Metax'] = True

                if config.IGNORE_TIMESTAMPS != 'true':
                    if metax['modified'] != filesystem['modified']:
                        errors['Node modification timestamp different for filesystem and Metax'] = True

            # if known in both metax and nextcloud, check if file details agree
            if metax and nextcloud and nextcloud['type'] == 'file':

                if metax['size'] != nextcloud['size']:
                    errors['Node size different for Nextcloud and Metax'] = True

                if config.IGNORE_TIMESTAMPS != 'true':
                    if metax['modified'] != nextcloud['modified']:
                        errors['Node modification timestamp different for Nextcloud and Metax'] = True

            # if known in both IDA and metax, check if file details agree
            if ida and metax:

                if ida['size'] != metax['size']:
                    errors['Node size different for IDA and Metax'] = True

                if config.IGNORE_TIMESTAMPS != 'true':
                    if ida['modified'] != metax['modified']:
                        errors['Node modification timestamp different for IDA and Metax'] = True
                    if ida['frozen'] != metax['frozen']:
                        errors['Node frozen timestamp different for IDA and Metax'] = True

                if ida['checksum'] != metax['checksum']:
                    errors['Node checksum different for IDA and Metax'] = True

                if ida['pid'] != metax['pid']:
                    errors['Node pid different for IDA and Metax'] = True

            # if known in IDA and a replication timestamp is defined in the IDA details, check that the replicated file exists and its details agree
            if ida:
                
                replicated = ida.get('replicated', False)
    
                if replicated == "None" or replicated == None:
                    replicated = False

                if replicated != False:

                    full_pathname = "%s/projects/%s%s" % (config.DATA_REPLICATION_ROOT, config.PROJECT, pathname[6:])

                    #if config.DEBUG == 'true':
                    #    sys.stderr.write("REPLICATION PATHNAME: %s\n" % full_pathname)

                    path = Path(full_pathname)

                    if path.exists():

                        if path.is_file():

                            fsstat = os.stat(full_pathname)
                            size = fsstat.st_size

                            node['replication'] = {'type': 'file', 'size': size}

                            if ida['size'] != size:
                                errors['Node size different for IDA and replication'] = True

                        else:

                            node['replication'] = {'type': 'folder'}
                            errors['Node type different for IDA and replication'] = True

                    else:
                        errors['Node does not exist in replication'] = True

        # If any errors were detected, add the node to the set of invalid nodes
        # and increment the invalid node count

        if len(errors) > 0:
            node['errors'] = list(errors.keys())
            invalidNodes[pathname] = node
            invalidNodeCount = invalidNodeCount + 1

            if config.DEBUG == 'true':
                for error in node['errors']:
                    sys.stderr.write("Error: %s\n" % error)

    # Output report

    sys.stdout.write("{\n")
    sys.stdout.write("\"project\": %s,\n" % str(json.dumps(config.PROJECT)))
    sys.stdout.write("\"start\": %s,\n" % str(json.dumps(config.START)))
    sys.stdout.write("\"end\": %s,\n" % str(json.dumps(datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"))))
    sys.stdout.write("\"filesystemNodeCount\": %d,\n" % counts['filesystemNodeCount'])
    sys.stdout.write("\"nextcloudNodeCount\": %d,\n" % counts['nextcloudNodeCount'])
    sys.stdout.write("\"idaNodeCount\": %d,\n" % counts['idaNodeCount'])
    sys.stdout.write("\"metaxNodeCount\": %d,\n" % counts['metaxNodeCount'])
    sys.stdout.write("\"invalidNodeCount\": %d" % invalidNodeCount)

    if invalidNodeCount > 0:

        first = True

        sys.stdout.write(",\n\"invalidNodes\": {\n")

        for pathname, node in invalidNodes.items():

            if not first:
                sys.stdout.write(",\n")

            first = False

            sys.stdout.write("%s: {" % str(json.dumps(pathname)))
            sys.stdout.write("\n\"errors\": %s" % str(json.dumps(node['errors'])))

            try:
                node_details = node['filesystem']
                sys.stdout.write(",\n\"filesystem\": {")
                sys.stdout.write("\n\"type\": \"%s\"" % node_details['type'])
                if node_details['type'] == 'file':
                    try:
                        sys.stdout.write(",\n\"size\": %d" % node_details['size'])
                    except:
                        pass
                    try:
                        sys.stdout.write(",\n\"modified\": \"%s\"" % node_details['modified'])
                    except:
                        pass
                sys.stdout.write("}")
            except:
                pass

            try:
                node_details = node['nextcloud']
                sys.stdout.write(",\n\"nextcloud\": {")
                sys.stdout.write("\n\"type\": \"%s\"" % node_details['type'])
                if node_details['type'] == 'file':
                    try:
                        sys.stdout.write(",\n\"size\": %d" % node_details['size'])
                    except:
                        pass
                    try:
                        sys.stdout.write(",\n\"modified\": \"%s\"" % node_details['modified'])
                    except:
                        pass
                sys.stdout.write("\n}")
            except:
                pass

            try:
                node_details = node['ida']
                sys.stdout.write(",\n\"ida\": {")
                sys.stdout.write("\n\"type\": \"file\"")
                try:
                    sys.stdout.write(",\n\"size\": %d" % node_details['size'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"pid\": \"%s\"" % node_details['pid'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"checksum\": \"%s\"" % node_details['checksum'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"modified\": \"%s\"" % node_details['modified'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"frozen\": \"%s\"" % node_details['frozen'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"replicated\": \"%s\"" % node_details['replicated'])
                except:
                    pass
                sys.stdout.write("\n}")
            except:
                pass

            try:
                node_details = node['metax']
                sys.stdout.write(",\n\"metax\": {")
                sys.stdout.write("\n\"type\": \"file\"")
                try:
                    sys.stdout.write(",\n\"size\": %d" % node_details['size'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"pid\": \"%s\"" % node_details['pid'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"checksum\": \"%s\"" % node_details['checksum'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"modified\": \"%s\"" % node_details['modified'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"frozen\": \"%s\"" % node_details['frozen'])
                except:
                    pass
                sys.stdout.write("\n}")
            except:
                pass

            try:
                node_details = node['replication']
                sys.stdout.write(",\n\"replication\": {")
                sys.stdout.write("\n\"type\": \"file\"")
                try:
                    sys.stdout.write(",\n\"size\": %d" % node_details['size'])
                except:
                    pass
                try:
                    sys.stdout.write(",\n\"modified\": \"%s\"" % node_details['modified'])
                except:
                    pass
                sys.stdout.write("}")
            except:
                pass
        
            sys.stdout.write("\n}")

        sys.stdout.write("\n}\n")

    sys.stdout.write("}\n")

    logging.info("DONE")
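
# As a side note (not part of audit_project above): the report assembled by hand
# with sys.stdout.write() calls could also be built as a plain dict and serialized
# with a single json.dumps() call. The sketch below is illustrative only; it
# assumes the same config, counts, invalidNodes and invalidNodeCount values as in
# the function above and omits the per-source detail blocks for brevity.
import json
import sys
from datetime import datetime


def write_report_sketch(config, counts, invalidNodes, invalidNodeCount):
    report = {
        "project": config.PROJECT,
        "start": config.START,
        "end": datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
        "filesystemNodeCount": counts['filesystemNodeCount'],
        "nextcloudNodeCount": counts['nextcloudNodeCount'],
        "idaNodeCount": counts['idaNodeCount'],
        "metaxNodeCount": counts['metaxNodeCount'],
        "invalidNodeCount": invalidNodeCount,
    }
    if invalidNodeCount > 0:
        # only include the recorded error strings for each invalid node
        report["invalidNodes"] = {
            pathname: {"errors": node["errors"]}
            for pathname, node in invalidNodes.items()
        }
    sys.stdout.write(json.dumps(report, indent=2) + "\n")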
Example #34
0
class PiecewiseConstantFunction(Generic[T]):
    def __init__(self, initial_value: float = 0) -> None:
        """ Initialize the constant function to a particular value

        :param initial_value: the starting value for the function
        """
        self.breakpoints = SortedDict()
        self._initial_value: float = initial_value

    def add_breakpoint(self,
                       xval: XValue[T],
                       yval: float,
                       squash: bool = True) -> None:
        """ Add a breakpoint to the function and update the value

        Let f(x) be the original function, and next_bp be the first breakpoint > xval; after calling
        this method, the function will be modified to f'(x) = yval for x \in [xval, next_bp)

        :param xval: the x-position of the breakpoint to add/modify
        :param yval: the value to set the function to at xval
        :param squash: if True and f(xval) = yval before calling this method, the function will remain unchanged
        """
        if squash and self.call(xval) == yval:
            return
        self.breakpoints[xval] = yval

    def add_delta(self, xval: XValue[T], delta: float) -> None:
        """ Modify the function value for x >= xval

        Let f(x) be the original function; After calling this method,
        the function will be modified to f'(x) = f(x) + delta for all x >= xval

        :param xval: the x-position of the breakpoint to add/modify
        :param delta: the amount to shift the function value by at xval
        """
        if delta == 0:
            return

        if xval not in self.breakpoints:
            self.breakpoints[xval] = self.call(xval)

        for x in self.breakpoints.irange(xval):
            self.breakpoints[x] += delta

        self.values.cache_clear()
        self.integrals.cache_clear()
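        # Example (illustrative): if f(x) = 3.0 for x >= 2 (initial value 0.0),
        # then add_delta(5, -1) first records a breakpoint at 5 with the current
        # value 3.0 and then shifts it, so f(x) = 2.0 for x >= 5 afterwards.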

    def call(self, xval: XValue[T]) -> float:
        """ Compute the output of the function at a point

        :param xval: the x-position to compute
        :returns: f(xval)
        """
        if len(self.breakpoints) == 0 or xval < self.breakpoints.keys()[0]:
            return self._initial_value
        else:
            lower_index = self.breakpoints.bisect(xval) - 1
            return self.breakpoints.values()[lower_index]
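        # Example (illustrative): with breakpoints {1: 5.0, 3: 2.0} and initial
        # value 0.0, call(0) == 0.0, call(1) == 5.0, call(2) == 5.0, call(4) == 2.0.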

    def _breakpoint_info(
        self, index: Optional[int]
    ) -> Tuple[Optional[int], Optional[XValue[T]], float]:
        """ Helper function for computing breakpoint information

        :param index: index of the breakpoint to compute
        :returns: (index, breakpoint, value)
          * index is the breakpoint index (if it exists), or None if we're off the end
          * breakpoint is the x-value of the breakpoint, or None if we're off the end
          * value is f(breakpoint), or f(last_breakpoint) if we're off the end
        """
        try:
            breakpoint, value = self.breakpoints.peekitem(index)
        except IndexError:
            index = None
            breakpoint, value = None, self.breakpoints.values()[-1]
        return (index, breakpoint, value)

    @lru_cache(maxsize=_LRU_CACHE_SIZE
               )  # cache results of calls to this function
    def values(self, start: XValue[T], stop: XValue[T],
               step: XValueDiff[T]) -> 'SortedDict[XValue[T], float]':
        """ Compute a sequence of values of the function

        This is more efficient than [self.call(xval) for xval in range(start, stop, step)] because each self.call(..)
        takes O(log n) time due to the binary tree structure of self._breakpoints.  This method can compute the range
        of values in linear time in the range, which is significantly faster for large value ranges.

        :param start: lower bound of value sequence
        :param stop: upper bound of value sequence
        :param step: width between points in the sequence
        :returns: a SortedDict of the values of the function between start and stop, with the x-distance between
            each data-point equal to `step`; like normal "range" functions the right endpoint is not included
        """

        step = step or (stop - start)
        if len(self.breakpoints) == 0:
            num_values = int(math.ceil((stop - start) / step))
            return SortedDict([(start + step * i, self._initial_value)
                               for i in range(num_values)])

        curr_xval = start
        curr_value = self.call(start)
        next_index, next_breakpoint, next_value = self._breakpoint_info(
            self.breakpoints.bisect(start))

        sequence = SortedDict()
        while curr_xval < stop:
            sequence[curr_xval] = curr_value

            next_xval = min(stop, curr_xval + step)
            while next_breakpoint and next_xval >= next_breakpoint:
                assert next_index is not None  # if next_breakpoint is set, next_index should also be set
                curr_value = next_value
                next_index, next_breakpoint, next_value = self._breakpoint_info(
                    next_index + 1)
            curr_xval = next_xval

        return sequence
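        # Example (illustrative): with breakpoints {1: 5.0, 3: 2.0} and initial
        # value 0.0, values(0, 4, 1) yields {0: 0.0, 1: 5.0, 2: 5.0, 3: 2.0};
        # like range(), the right endpoint 4 is excluded.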

    @lru_cache(maxsize=_LRU_CACHE_SIZE
               )  # cache results of calls to this function
    def integrals(
        self,
        start: XValue[T],
        stop: XValue[T],
        step: XValueDiff[T],
        transform: Callable[[XValueDiff[T]], float] = lambda x: cast(float, x),
    ) -> 'SortedDict[XValue[T], float]':
        """ Compute a sequence of integrals of the function

        :param start: lower bound of integral sequence
        :param stop: upper bound of integral sequence
        :param step: width of each "chunk" of the integral sequence
        :param transform: function to apply to x-widths before computing the integral
        :returns: a SortedDict of the numeric integral values of the function between start and stop;
            each integral has a range of size `step`, and the key-value is the left endpoint of the chunk
        """
        step = step or (stop - start)
        if len(self.breakpoints) == 0:
            # If there are no breakpoints, just split up the range into even widths and compute
            # (width * self._initial_value) for each chunk.
            step_width = transform(step)
            range_width = transform(stop - start)
            num_full_chunks = int(range_width // step_width)
            sequence = SortedDict([(start + step * i,
                                    step_width * self._initial_value)
                                   for i in range(num_full_chunks)])

            # If the width does not evenly divide the range, compute the last chunk separately
            if range_width % step_width != 0:
                sequence[
                    start + step *
                    num_full_chunks] = range_width % step_width * self._initial_value
            return sequence

        # Set up starting loop parameters
        curr_xval = start
        curr_value = self.call(start)
        next_index, next_breakpoint, next_value = self._breakpoint_info(
            self.breakpoints.bisect(start))

        # Loop through the entire range and compute the integral of each chunk
        sequence = SortedDict()
        while curr_xval < stop:
            orig_xval = curr_xval
            next_xval = min(stop, curr_xval + step)

            # For each breakpoint in [curr_xval, next_xval), compute the area of that sub-chunk
            next_integral: float = 0
            while next_breakpoint and next_xval >= next_breakpoint:
                assert next_index is not None  # if next_breakpoint is set, next_index should also be set
                next_integral += transform(next_breakpoint -
                                           curr_xval) * curr_value
                curr_xval = next_breakpoint
                curr_value = next_value
                next_index, next_breakpoint, next_value = self._breakpoint_info(
                    next_index + 1)

            # Handle any remaining width between the last breakpoint and the end of the chunk
            next_integral += transform(next_xval - curr_xval) * curr_value
            sequence[orig_xval] = next_integral

            curr_xval = next_xval

        return sequence

    def integral(
        self,
        start: XValue[T],
        stop: XValue[T],
        transform: Callable[[XValueDiff[T]], float] = lambda x: cast(float, x),
    ) -> float:
        """ Helper function to compute the integral of the whole specified range

        :param start: lower bound of the integral
        :param stop: upper bound of the integral
        :returns: the integral of the function between start and stop
        """
        return self.integrals(start, stop, (stop - start),
                              transform).values()[0]

    def __str__(self) -> str:
        ret = f'{self._initial_value}, x < {self.breakpoints.keys()[0]}\n'
        for xval, yval in self.breakpoints.items():
            ret += f'{yval}, x >= {xval}\n'
        return ret

    def __add__(
        self, other: 'PiecewiseConstantFunction[T]'
    ) -> 'PiecewiseConstantFunction[T]':
        new_func: 'PiecewiseConstantFunction[T]' = PiecewiseConstantFunction(
            self._initial_value + other._initial_value)
        for xval, y0, y1 in _merged_breakpoints(self, other):
            new_func.add_breakpoint(xval, y0 + y1)
        return new_func

    def __sub__(
        self, other: 'PiecewiseConstantFunction[T]'
    ) -> 'PiecewiseConstantFunction[T]':
        new_func: 'PiecewiseConstantFunction[T]' = PiecewiseConstantFunction(
            self._initial_value - other._initial_value)
        for xval, y0, y1 in _merged_breakpoints(self, other):
            new_func.add_breakpoint(xval, y0 - y1)
        return new_func

    def __mul__(
        self, other: 'PiecewiseConstantFunction[T]'
    ) -> 'PiecewiseConstantFunction[T]':
        new_func: 'PiecewiseConstantFunction[T]' = PiecewiseConstantFunction(
            self._initial_value * other._initial_value)
        for xval, y0, y1 in _merged_breakpoints(self, other):
            new_func.add_breakpoint(xval, y0 * y1)
        return new_func

    def __truediv__(
        self, other: 'PiecewiseConstantFunction[T]'
    ) -> 'PiecewiseConstantFunction[T]':
        try:
            new_func: 'PiecewiseConstantFunction[T]' = PiecewiseConstantFunction(
                self._initial_value / other._initial_value)
        except ZeroDivisionError:
            new_func = PiecewiseConstantFunction()

        for xval, y0, y1 in _merged_breakpoints(self, other):
            try:
                new_func.add_breakpoint(xval, y0 / y1)
            except ZeroDivisionError:
                new_func.add_breakpoint(xval, 0)
        return new_func
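A minimal usage sketch of the methods above. It assumes the PiecewiseConstantFunction(initial_value) constructor and the add_breakpoint(xval, yval) method referenced in the operators; the numbers are purely illustrative:

fn = PiecewiseConstantFunction(1.0)
fn.add_breakpoint(5, 3.0)   # f(x) = 3.0 for x >= 5
fn.add_breakpoint(10, 0.0)  # f(x) = 0.0 for x >= 10

fn.values(0, 15, 1)     # SortedDict({0: 1.0, ..., 5: 3.0, ..., 10: 0.0, ..., 14: 0.0})
fn.integrals(0, 15, 5)  # SortedDict({0: 5.0, 5: 15.0, 10: 0.0})
fn.integral(0, 15)      # 20.0 == 5*1.0 + 5*3.0 + 5*0.0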
Example #35
0
class IntervalTree(collections.MutableSet):
    """
    A binary lookup tree of intervals.
    The intervals contained in the tree are represented using ``Interval(a, b, data)`` objects.
    Each such object represents a half-open interval ``[a, b)`` with optional data.
    
    Examples:
    ---------
    
    Initialize a blank tree::
    
        >>> tree = IntervalTree()
        >>> tree
        IntervalTree()
    
    Initialize a tree from an iterable set of Intervals in O(n * log n)::
    
        >>> tree = IntervalTree([Interval(-10, 10), Interval(-20.0, -10.0)])
        >>> tree
        IntervalTree([Interval(-20.0, -10.0), Interval(-10, 10)])
        >>> len(tree)
        2
    
    Note that this is a set, i.e. repeated intervals are ignored. However,
    Intervals with different data fields are regarded as different::
    
        >>> tree = IntervalTree([Interval(-10, 10), Interval(-10, 10), Interval(-10, 10, "x")])
        >>> tree
        IntervalTree([Interval(-10, 10), Interval(-10, 10, 'x')])
        >>> len(tree)
        2
    
    Insertions::
        >>> tree = IntervalTree()
        >>> tree[0:1] = "data"
        >>> tree.add(Interval(10, 20))
        >>> tree.addi(19.9, 20)
        >>> tree
        IntervalTree([Interval(0, 1, 'data'), Interval(10, 20), Interval(19.9, 20)])
        >>> tree.update([Interval(19.9, 20.1), Interval(20.1, 30)])
        >>> len(tree)
        5

        Inserting the same Interval twice does nothing::
            >>> tree = IntervalTree()
            >>> tree[-10:20] = "arbitrary data"
            >>> tree[-10:20] = None  # Note that this is also an insertion
            >>> tree
            IntervalTree([Interval(-10, 20), Interval(-10, 20, 'arbitrary data')])
            >>> tree[-10:20] = None  # This won't change anything
            >>> tree[-10:20] = "arbitrary data" # Neither will this
            >>> len(tree)
            2

    Deletions::
        >>> tree = IntervalTree(Interval(b, e) for b, e in [(-10, 10), (-20, -10), (10, 20)])
        >>> tree
        IntervalTree([Interval(-20, -10), Interval(-10, 10), Interval(10, 20)])
        >>> tree.remove(Interval(-10, 10))
        >>> tree
        IntervalTree([Interval(-20, -10), Interval(10, 20)])
        >>> tree.remove(Interval(-10, 10))
        Traceback (most recent call last):
        ...
        ValueError
        >>> tree.discard(Interval(-10, 10))  # Same as remove, but no exception on failure
        >>> tree
        IntervalTree([Interval(-20, -10), Interval(10, 20)])
        
    Delete intervals, overlapping a given point::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.1)
        >>> tree
        IntervalTree([Interval(-1.1, 1.1)])
        
    Delete intervals, overlapping an interval::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> tree.remove_overlap(0, 0.5)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.7, 1.8)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.6, 1.6)  # Null interval does nothing
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.6, 1.5)  # Ditto
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        
    Delete intervals, enveloped in the range::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> tree.remove_envelop(-1.0, 1.5)
        >>> tree
        IntervalTree([Interval(-1.1, 1.1), Interval(0.5, 1.7)])
        >>> tree.remove_envelop(-1.1, 1.5)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_envelop(0.5, 1.5)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_envelop(0.5, 1.7)
        >>> tree
        IntervalTree()
        
    Point/interval overlap queries::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> assert tree[-1.1]         == set([Interval(-1.1, 1.1)])
        >>> assert tree.search(1.1)   == set([Interval(-0.5, 1.5), Interval(0.5, 1.7)])   # Same as tree[1.1]
        >>> assert tree[-0.5:0.5]     == set([Interval(-0.5, 1.5), Interval(-1.1, 1.1)])  # Interval overlap query
        >>> assert tree.search(1.5, 1.5) == set()                                         # Same as tree[1.5:1.5]
        >>> assert tree.search(1.5) == set([Interval(0.5, 1.7)])                          # Same as tree[1.5]

        >>> assert tree.search(1.7, 1.8) == set()

    Envelop queries::
    
        >>> assert tree.search(-0.5, 0.5, strict=True) == set()
        >>> assert tree.search(-0.4, 1.7, strict=True) == set([Interval(0.5, 1.7)])
        
    Membership queries::

        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> Interval(-0.5, 0.5) in tree
        False
        >>> Interval(-1.1, 1.1) in tree
        True
        >>> Interval(-1.1, 1.1, "x") in tree
        False
        >>> tree.overlaps(-1.1)
        True
        >>> tree.overlaps(1.7)
        False
        >>> tree.overlaps(1.7, 1.8)
        False
        >>> tree.overlaps(-1.2, -1.1)
        False
        >>> tree.overlaps(-1.2, -1.0)
        True
    
    Sizing::

        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> len(tree)
        3
        >>> tree.is_empty()
        False
        >>> IntervalTree().is_empty()
        True
        >>> not tree
        False
        >>> not IntervalTree()
        True
        >>> print(tree.begin())    # using print() because of floats in Python 2.6
        -1.1
        >>> print(tree.end())      # ditto
        1.7
        
    Iteration::

        >>> tree = IntervalTree([Interval(-11, 11), Interval(-5, 15), Interval(5, 17)])
        >>> [iv.begin for iv in sorted(tree)]
        [-11, -5, 5]
        >>> assert tree.items() == set([Interval(-5, 15), Interval(-11, 11), Interval(5, 17)])

    Copy- and typecasting, pickling::
    
        >>> tree0 = IntervalTree([Interval(0, 1, "x"), Interval(1, 2, ["x"])])
        >>> tree1 = IntervalTree(tree0)  # Shares Interval objects
        >>> tree2 = tree0.copy()         # Shallow copy (same as above, as Intervals are singletons)
        >>> import pickle
        >>> tree3 = pickle.loads(pickle.dumps(tree0))  # Deep copy
        >>> list(tree0[1])[0].data[0] = "y"  # affects shallow copies, but not deep copies
        >>> tree0
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])])
        >>> tree1
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])])
        >>> tree2
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])])
        >>> tree3
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['x'])])
        
    Equality testing::
    
        >>> IntervalTree([Interval(0, 1)]) == IntervalTree([Interval(0, 1)])
        True
        >>> IntervalTree([Interval(0, 1)]) == IntervalTree([Interval(0, 1, "x")])
        False
    """
    @classmethod
    def from_tuples(cls, tups):
        """
        Create a new IntervalTree from an iterable of 2- or 3-tuples,
         where the tuple lists begin, end, and optionally data.
        """
        ivs = [Interval(*t) for t in tups]
        return IntervalTree(ivs)

    def __init__(self, intervals=None):
        """
        Set up a tree. If intervals is provided, add all the intervals 
        to the tree.
        
        Completes in O(n*log n) time.
        """
        intervals = set(intervals) if intervals is not None else set()
        for iv in intervals:
            if iv.is_null():
                raise ValueError(
                    "IntervalTree: Null Interval objects not allowed in IntervalTree:"
                    " {0}".format(iv))
        self.all_intervals = intervals
        self.top_node = Node.from_intervals(self.all_intervals)
        self.boundary_table = SortedDict()
        for iv in self.all_intervals:
            self._add_boundaries(iv)

    def copy(self):
        """
        Construct a new IntervalTree using shallow copies of the 
        intervals in the source tree.
        
        Completes in O(n*log n) time.
        :rtype: IntervalTree
        """
        return IntervalTree(iv.copy() for iv in self)

    def _add_boundaries(self, interval):
        """
        Records the boundaries of the interval in the boundary table.
        """
        begin = interval.begin
        end = interval.end
        if begin in self.boundary_table:
            self.boundary_table[begin] += 1
        else:
            self.boundary_table[begin] = 1

        if end in self.boundary_table:
            self.boundary_table[end] += 1
        else:
            self.boundary_table[end] = 1

    def _remove_boundaries(self, interval):
        """
        Removes the boundaries of the interval from the boundary table.
        """
        begin = interval.begin
        end = interval.end
        if self.boundary_table[begin] == 1:
            del self.boundary_table[begin]
        else:
            self.boundary_table[begin] -= 1

        if self.boundary_table[end] == 1:
            del self.boundary_table[end]
        else:
            self.boundary_table[end] -= 1

    def add(self, interval):
        """
        Adds an interval to the tree, if not already present.
        
        Completes in O(log n) time.
        """
        if interval in self:
            return

        if interval.is_null():
            raise ValueError(
                "IntervalTree: Null Interval objects not allowed in IntervalTree:"
                " {0}".format(interval))

        if not self.top_node:
            self.top_node = Node.from_interval(interval)
        else:
            self.top_node = self.top_node.add(interval)
        self.all_intervals.add(interval)
        self._add_boundaries(interval)

    append = add

    def addi(self, begin, end, data=None):
        """
        Shortcut for add(Interval(begin, end, data)).
        
        Completes in O(log n) time.
        """
        return self.add(Interval(begin, end, data))

    appendi = addi

    def update(self, intervals):
        """
        Given an iterable of intervals, add them to the tree.
        
        Completes in O(m*log(n+m)) time, where m = number of intervals
        to add.
        """
        for iv in intervals:
            self.add(iv)

    def extend(self, intervals):
        """
        Deprecated: Replaced by update().
        """
        warn(
            "IntervalTree.extend() has been deprecated. Consider using update() instead",
            DeprecationWarning)
        self.update(intervals)

    def remove(self, interval):
        """
        Removes an interval from the tree, if present. If not, raises 
        ValueError.
        
        Completes in O(log n) time.
        """
        #self.verify()
        if interval not in self:
            #print(self.all_intervals)
            raise ValueError
        self.top_node = self.top_node.remove(interval)
        self.all_intervals.remove(interval)
        self._remove_boundaries(interval)
        #self.verify()

    def removei(self, begin, end, data=None):
        """
        Shortcut for remove(Interval(begin, end, data)).
        
        Completes in O(log n) time.
        """
        return self.remove(Interval(begin, end, data))

    def discard(self, interval):
        """
        Removes an interval from the tree, if present. If not, does 
        nothing.
        
        Completes in O(log n) time.
        """
        if interval not in self:
            return
        self.all_intervals.discard(interval)
        self.top_node = self.top_node.discard(interval)
        self._remove_boundaries(interval)

    def discardi(self, begin, end, data=None):
        """
        Shortcut for discard(Interval(begin, end, data)).
        
        Completes in O(log n) time.
        """
        return self.discard(Interval(begin, end, data))

    def difference(self, other):
        """
        Returns a new tree, comprising all intervals in self but not
        in other.
        """
        ivs = set()
        for iv in self:
            if iv not in other:
                ivs.add(iv)
        return IntervalTree(ivs)

    def difference_update(self, other):
        """
        Removes all intervals in other from self.
        """
        for iv in other:
            self.discard(iv)

    def union(self, other):
        """
        Returns a new tree, comprising all intervals from self
        and other.
        """
        return IntervalTree(set(self).union(other))

    def intersection(self, other):
        """
        Returns a new tree of all intervals common to both self and
        other.
        """
        ivs = set()
        shorter, longer = sorted([self, other], key=len)
        for iv in shorter:
            if iv in longer:
                ivs.add(iv)
        return IntervalTree(ivs)

    def intersection_update(self, other):
        """
        Removes intervals from self unless they also exist in other.
        """
        for iv in self:
            if iv not in other:
                self.remove(iv)

    def symmetric_difference(self, other):
        """
        Return a tree with elements only in self or other but not
        both.
        """
        if not isinstance(other, set): other = set(other)
        me = set(self)
        ivs = me - other + (other - me)
        return IntervalTree(ivs)

    def symmetric_difference_update(self, other):
        """
        Throws out all intervals except those only in self or other,
        not both.
        """
        other = set(other)
        for iv in self:
            if iv in other:
                self.remove(iv)
                other.remove(iv)
        self.update(other)

    def remove_overlap(self, begin, end=None):
        """
        Removes all intervals overlapping the given point or range.
        
        Completes in O((r+m)*log n) time, where:
          * n = size of the tree
          * m = number of matches
          * r = size of the search range (this is 1 for a point)
        """
        hitlist = self.search(begin, end)
        for iv in hitlist:
            self.remove(iv)

    def remove_envelop(self, begin, end):
        """
        Removes all intervals completely enveloped in the given range.
        
        Completes in O((r+m)*log n) time, where:
          * n = size of the tree
          * m = number of matches
          * r = size of the search range (this is 1 for a point)
        """
        hitlist = self.search(begin, end, strict=True)
        for iv in hitlist:
            self.remove(iv)

    def chop(self, begin, end, datafunc=None):
        """
        Like remove_envelop(), but trims back Intervals hanging into
        the chopped area so that nothing overlaps.
        """
        insertions = set()
        begin_hits = [iv for iv in self[begin] if iv.begin < begin]
        end_hits = [iv for iv in self[end] if iv.end > end]

        if datafunc:
            for iv in begin_hits:
                insertions.add(Interval(iv.begin, begin, datafunc(iv, True)))
            for iv in end_hits:
                insertions.add(Interval(end, iv.end, datafunc(iv, False)))
        else:
            for iv in begin_hits:
                insertions.add(Interval(iv.begin, begin, iv.data))
            for iv in end_hits:
                insertions.add(Interval(end, iv.end, iv.data))

        self.remove_envelop(begin, end)
        self.difference_update(begin_hits)
        self.difference_update(end_hits)
        self.update(insertions)

    def slice(self, point, datafunc=None):
        """
        Split Intervals that overlap point into two new Intervals. If
        specified, uses datafunc(interval, islower=True/False) to
        set the data field of the new Intervals.
        :param point: where to slice
        :param datafunc(interval, islower): callable returning a new
        value for the interval's data field
        """
        hitlist = set(iv for iv in self[point] if iv.begin < point)
        insertions = set()
        if datafunc:
            for iv in hitlist:
                insertions.add(Interval(iv.begin, point, datafunc(iv, True)))
                insertions.add(Interval(point, iv.end, datafunc(iv, False)))
        else:
            for iv in hitlist:
                insertions.add(Interval(iv.begin, point, iv.data))
                insertions.add(Interval(point, iv.end, iv.data))
        self.difference_update(hitlist)
        self.update(insertions)

    def clear(self):
        """
        Empties the tree.

        Completes in O(1) time.
        """
        self.__init__()

    def find_nested(self):
        """
        Returns a dictionary mapping parent intervals to sets of 
        intervals overlapped by and contained in the parent.
        
        Completes in O(n^2) time.
        :rtype: dict of [Interval, set of Interval]
        """
        result = {}

        def add_if_nested():
            if parent.contains_interval(child):
                if parent not in result:
                    result[parent] = set()
                result[parent].add(child)

        long_ivs = sorted(self.all_intervals,
                          key=Interval.length,
                          reverse=True)
        for i, parent in enumerate(long_ivs):
            for child in long_ivs[i + 1:]:
                add_if_nested()
        return result

    def overlaps(self, begin, end=None):
        """
        Returns whether some interval in the tree overlaps the given
        point or range.
        
        Completes in O(r*log n) time, where r is the size of the
        search range.
        :rtype: bool
        """
        if end is not None:
            return self.overlaps_range(begin, end)
        elif isinstance(begin, Number):
            return self.overlaps_point(begin)
        else:
            return self.overlaps_range(begin.begin, begin.end)

    def overlaps_point(self, p):
        """
        Returns whether some interval in the tree overlaps p.
        
        Completes in O(log n) time.
        :rtype: bool
        """
        if self.is_empty():
            return False
        return bool(self.top_node.contains_point(p))

    def overlaps_range(self, begin, end):
        """
        Returns whether some interval in the tree overlaps the given
        range. Returns False if given a null interval over which to
        test.
        
        Completes in O(r*log n) time, where r is the range length and n
        is the table size.
        :rtype: bool
        """
        if self.is_empty():
            return False
        elif begin >= end:
            return False
        elif self.overlaps_point(begin):
            return True
        return any(
            self.overlaps_point(bound) for bound in self.boundary_table
            if begin < bound < end)

    def split_overlaps(self):
        """
        Finds all intervals with overlapping ranges and splits them
        along the range boundaries.
        
        Completes in worst-case O(n^2*log n) time (many interval 
        boundaries are inside many intervals), best-case O(n*log n)
        time (small number of overlaps << n per interval).
        """
        if not self:
            return
        if len(self.boundary_table) == 2:
            return

        bounds = sorted(self.boundary_table)  # get bound locations

        new_ivs = set()
        for lbound, ubound in zip(bounds[:-1], bounds[1:]):
            for iv in self[lbound]:
                new_ivs.add(Interval(lbound, ubound, iv.data))

        self.__init__(new_ivs)

    def merge_overlaps(self, data_reducer=None, data_initializer=None):
        """
        Finds all intervals with overlapping ranges and merges them
        into a single interval. If provided, uses data_reducer and
        data_initializer with similar semantics to Python's built-in
        reduce(reducer_func[, initializer]), as follows:

        If data_reducer is set to a function, combines the data
        fields of the Intervals with
            current_reduced_data = data_reducer(current_reduced_data, new_data)
        If data_reducer is None, the merged Interval's data
        field will be set to None, ignoring all the data fields
        of the merged Intervals.

        On encountering the first Interval to merge, if
        data_initializer is None (default), uses the first
        Interval's data field as the first value for
        current_reduced_data. If data_initializer is not None,
        current_reduced_data is set to a shallow copy of
        data_initializer created with
            copy.copy(data_initializer).

        Completes in O(n*logn).
        """
        if not self:
            return

        sorted_intervals = sorted(self.all_intervals)  # get sorted intervals
        merged = []
        # use mutable object to allow new_series() to modify it
        current_reduced = [None]
        higher = None  # iterating variable, which new_series() needs access to

        def new_series():
            if data_initializer is None:
                current_reduced[0] = higher.data
                merged.append(higher)
                return
            else:  # data_initializer is not None
                current_reduced[0] = copy(data_initializer)
                #current_reduced[0] = data_initializer()
                #print "in new_series:"
                #print current_reduced
                current_reduced[0] = data_reducer(current_reduced[0],
                                                  higher.data)
                #current_reduced[0] = data_reducer(data_initializer, higher.data)
                merged.append(
                    Interval(higher.begin, higher.end, current_reduced[0]))

        for higher in sorted_intervals:
            if merged:  # series already begun
                lower = merged[-1]
                if higher.begin <= lower.end:  # should merge
                    upper_bound = max(lower.end, higher.end)
                    if data_reducer is not None:
                        current_reduced[0] = data_reducer(
                            current_reduced[0], higher.data)
                    else:  # annihilate the data, since we don't know how to merge it
                        current_reduced[0] = None
                    merged[-1] = Interval(lower.begin, upper_bound,
                                          current_reduced[0])
                else:
                    new_series()
            else:  # not merged; is first of Intervals to merge
                new_series()

        self.__init__(merged)

    def merge_equals(self, data_reducer=None, data_initializer=None):
        """
        Finds all intervals with equal ranges and merges them
        into a single interval. If provided, uses data_reducer and
        data_initializer with similar semantics to Python's built-in
        reduce(reducer_func[, initializer]), as follows:

        If data_reducer is set to a function, combines the data
        fields of the Intervals with
            current_reduced_data = data_reducer(current_reduced_data, new_data)
        If data_reducer is None, the merged Interval's data
        field will be set to None, ignoring all the data fields
        of the merged Intervals.

        On encountering the first Interval to merge, if
        data_initializer is None (default), uses the first
        Interval's data field as the first value for
        current_reduced_data. If data_initializer is not None,
        current_reduced_data is set to a shallow copy of
        data_initializer created with
            copy.copy(data_initializer).

        Completes in O(n*logn).
        """
        if not self:
            return

        sorted_intervals = sorted(self.all_intervals)  # get sorted intervals
        merged = []
        # use mutable object to allow new_series() to modify it
        current_reduced = [None]
        higher = None  # iterating variable, which new_series() needs access to

        def new_series():
            if data_initializer is None:
                current_reduced[0] = higher.data
                merged.append(higher)
                return
            else:  # data_initializer is not None
                current_reduced[0] = copy(data_initializer)
                current_reduced[0] = data_reducer(current_reduced[0],
                                                  higher.data)
                merged.append(
                    Interval(higher.begin, higher.end, current_reduced[0]))

        for higher in sorted_intervals:
            if merged:  # series already begun
                lower = merged[-1]
                if higher.range_matches(lower):  # should merge
                    upper_bound = max(lower.end, higher.end)
                    if data_reducer is not None:
                        current_reduced[0] = data_reducer(
                            current_reduced[0], higher.data)
                    else:  # annihilate the data, since we don't know how to merge it
                        current_reduced[0] = None
                    merged[-1] = Interval(lower.begin, upper_bound,
                                          current_reduced[0])
                else:
                    new_series()
            else:  # not merged; is first of Intervals to merge
                new_series()

        self.__init__(merged)

    def items(self):
        """
        Constructs and returns a set of all intervals in the tree. 
        
        Completes in O(n) time.
        :rtype: set of Interval
        """
        return set(self.all_intervals)

    def is_empty(self):
        """
        Returns whether the tree is empty.
        
        Completes in O(1) time.
        :rtype: bool
        """
        return 0 == len(self)

    def search(self, begin, end=None, strict=False):
        """
        Returns a set of all intervals overlapping the given range. Or,
        if strict is True, returns the set of all intervals fully
        contained in the range [begin, end].
        
        Completes in O(m + k*log n) time, where:
          * n = size of the tree
          * m = number of matches
          * k = size of the search range (this is 1 for a point)
        :rtype: set of Interval
        """
        root = self.top_node
        if not root:
            return set()
        if end is None:
            try:
                iv = begin
                return self.search(iv.begin, iv.end, strict=strict)
            except:
                return root.search_point(begin, set())
        elif begin >= end:
            return set()
        else:
            result = root.search_point(begin, set())

            boundary_table = self.boundary_table
            bound_begin = boundary_table.bisect_left(begin)
            bound_end = boundary_table.bisect_left(
                end)  # exclude final end bound
            result.update(
                root.search_overlap(
                    # slice notation is slightly slower
                    boundary_table.keys()[index]
                    for index in xrange(bound_begin, bound_end)))

            # TODO: improve strict search to use node info instead of less-efficient filtering
            if strict:
                result = set(iv for iv in result
                             if iv.begin >= begin and iv.end <= end)
            return result

    def begin(self):
        """
        Returns the lower bound of the first interval in the tree.
        
        Completes in O(1) time.
        """
        if not self.boundary_table:
            return 0
        return self.boundary_table.keys()[0]

    def end(self):
        """
        Returns the upper bound of the last interval in the tree.
        
        Completes in O(1) time.
        """
        if not self.boundary_table:
            return 0
        return self.boundary_table.keys()[-1]

    def range(self):
        """
        Returns a minimum-spanning Interval that encloses all the
        members of this IntervalTree. If the tree is empty, returns
        null Interval.
        :rtype: Interval
        """
        return Interval(self.begin(), self.end())

    def span(self):
        """
        Returns the length of the minimum-spanning Interval that
        encloses all the members of this IntervalTree. If the tree
        is empty, return 0.
        """
        if not self:
            return 0
        return self.end() - self.begin()

    def print_structure(self, tostring=False):
        """
        ## FOR DEBUGGING ONLY ##
        Pretty-prints the structure of the tree. 
        If tostring is true, prints nothing and returns a string.
        :rtype: None or str
        """
        if self.top_node:
            return self.top_node.print_structure(tostring=tostring)
        else:
            result = "<empty IntervalTree>"
            if not tostring:
                print(result)
            else:
                return result

    def verify(self):
        """
        ## FOR DEBUGGING ONLY ##
        Checks the table to ensure that the invariants are held.
        """
        if self.all_intervals:
            ## top_node.all_children() == self.all_intervals
            try:
                assert self.top_node.all_children() == self.all_intervals
            except AssertionError as e:
                print(
                    'Error: the tree and the membership set are out of sync!')
                tivs = set(self.top_node.all_children())
                print('top_node.all_children() - all_intervals:')
                try:
                    pprint
                except NameError:
                    from pprint import pprint
                pprint(tivs - self.all_intervals)
                print('all_intervals - top_node.all_children():')
                pprint(self.all_intervals - tivs)
                raise e

            ## All members are Intervals
            for iv in self:
                assert isinstance(iv, Interval), (
                    "Error: Only Interval objects allowed in IntervalTree:"
                    " {0}".format(iv))

            ## No null intervals
            for iv in self:
                assert not iv.is_null(), (
                    "Error: Null Interval objects not allowed in IntervalTree:"
                    " {0}".format(iv))

            ## Reconstruct boundary_table
            bound_check = {}
            for iv in self:
                if iv.begin in bound_check:
                    bound_check[iv.begin] += 1
                else:
                    bound_check[iv.begin] = 1
                if iv.end in bound_check:
                    bound_check[iv.end] += 1
                else:
                    bound_check[iv.end] = 1

            ## Reconstructed boundary table (bound_check) ==? boundary_table
            assert set(self.boundary_table.keys()) == set(bound_check.keys()),\
                'Error: boundary_table is out of sync with ' \
                'the intervals in the tree!'

            # For efficiency reasons this should be iteritems in Py2, but we
            # don't care much for efficiency in debug methods anyway.
            for key, val in self.boundary_table.items():
                assert bound_check[key] == val, \
                    'Error: boundary_table[{0}] should be {1},' \
                    ' but is {2}!'.format(
                        key, bound_check[key], val)

            ## Internal tree structure
            self.top_node.verify(set())
        else:
            ## Verify empty tree
            assert not self.boundary_table, \
                "Error: boundary table should be empty!"
            assert self.top_node is None, \
                "Error: top_node isn't None!"

    def score(self, full_report=False):
        """
        Returns a number between 0 and 1, indicating how suboptimal the tree
        is. The lower, the better. Roughly, this number represents the
        fraction of flawed Intervals in the tree.
        :rtype: float
        """
        if len(self) <= 2:
            return 0.0

        n = len(self)
        m = self.top_node.count_nodes()

        def s_center_score():
            """
            Returns a normalized score, indicating roughly how many times
            intervals share s_center with other intervals. Output is full-scale
            from 0 to 1.
            :rtype: float
            """
            raw = n - m
            maximum = n - 1
            return raw / float(maximum)

        report = {
            "depth": self.top_node.depth_score(n, m),
            "s_center": s_center_score(),
        }
        cumulative = max(report.values())
        report["_cumulative"] = cumulative
        if full_report:
            return report
        return cumulative

    def __getitem__(self, index):
        """
        Returns a set of all intervals overlapping the given index or 
        slice.
        
        Completes in O(k * log(n) + m) time, where:
          * n = size of the tree
          * m = number of matches
          * k = size of the search range (this is 1 for a point)
        :rtype: set of Interval
        """
        try:
            start, stop = index.start, index.stop
            if start is None:
                start = self.begin()
                if stop is None:
                    return set(self)
            if stop is None:
                stop = self.end()
            return self.search(start, stop)
        except AttributeError:
            return self.search(index)

    def __setitem__(self, index, value):
        """
        Adds a new interval to the tree. A shortcut for
        add(Interval(index.start, index.stop, value)).
        
        If an identical Interval object with equal range and data 
        already exists, does nothing.
        
        Completes in O(log n) time.
        """
        self.addi(index.start, index.stop, value)

    def __delitem__(self, point):
        """
        Delete all items overlapping point.
        """
        self.remove_overlap(point)

    def __contains__(self, item):
        """
        Returns whether item exists as an Interval in the tree.
        This method only returns True for exact matches; for
        overlaps, see the overlaps() method.
        
        Completes in O(1) time.
        :rtype: bool
        """
        # Removed point-checking code; it might trick the user into
        # thinking that this is O(1), which point-checking isn't.
        #if isinstance(item, Interval):
        return item in self.all_intervals
        #else:
        #    return self.contains_point(item)

    def containsi(self, begin, end, data=None):
        """
        Shortcut for (Interval(begin, end, data) in tree).
        
        Completes in O(1) time.
        :rtype: bool
        """
        return Interval(begin, end, data) in self

    def __iter__(self):
        """
        Returns an iterator over all the intervals in the tree.
        
        Completes in O(1) time.
        :rtype: collections.Iterable[Interval]
        """
        return self.all_intervals.__iter__()

    iter = __iter__

    def __len__(self):
        """
        Returns how many intervals are in the tree.
        
        Completes in O(1) time.
        :rtype: int
        """
        return len(self.all_intervals)

    def __eq__(self, other):
        """
        Whether two IntervalTrees are equal.
        
        Completes in O(n) time if sizes are equal; O(1) time otherwise.
        :rtype: bool
        """
        return (isinstance(other, IntervalTree)
                and self.all_intervals == other.all_intervals)

    def __repr__(self):
        """
        :rtype: str
        """
        ivs = sorted(self)
        if not ivs:
            return "IntervalTree()"
        else:
            return "IntervalTree({0})".format(ivs)

    __str__ = __repr__

    def __reduce__(self):
        """
        For pickle-ing.
        :rtype: tuple
        """
        return IntervalTree, (sorted(self.all_intervals), )
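A small sketch exercising merge_overlaps() with a data_reducer, using only the from_tuples() and merge_overlaps() methods defined above (it assumes the Interval and Node classes this tree builds on are importable):

t = IntervalTree.from_tuples([(1, 5, 'a'), (4, 9, 'b'), (10, 12, 'c')])
t.merge_overlaps(data_reducer=lambda cur, new: cur + new)  # concatenate data of merged intervals
sorted(t)   # [Interval(1, 9, 'ab'), Interval(10, 12, 'c')]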
Example #36
0
def getTokensA(pitches_occurences, durations_occurences, offsets_occurences,
               max_tokens):
    tokens = SortedDict()
    tokens[0] = 0
    counter = 1

    for pitch, occurences in pitches_occurences.items():
        if pitch not in tokens.keys():
            tokens[pitch] = counter
            counter += 1

    last_pitch = counter - 1

    tokens[32] = counter
    counter += 1
    tokens[45] = counter
    counter += 1
    tokens[90] = counter
    counter += 1
    tokens[120] = counter
    counter += 1

    last_velocity = counter - 1

    remaining_tokens = int(max_tokens - (counter + 1))
    current_counter = counter
    # durations
    durations = list(durations_occurences.keys())
    i = 0

    canEmplace = True
    while canEmplace:
        currentDuration = abs(durations[i])
        i += 1
        if currentDuration <= 1.0:
            if currentDuration not in tokens:
                tokens[currentDuration] = counter
                counter += 1
            if len(tokens) >= current_counter + (remaining_tokens / 2):
                canEmplace = False

    last_duration = counter - 1
    current_counter = counter
    canEmplace = True
    # offsets
    i = 0
    offsets = list(offsets_occurences.keys())
    while canEmplace:
        currentOffset = abs(offsets[i])
        i += 1
        if currentOffset <= 2.0:
            if currentOffset not in tokens:
                tokens[currentOffset] = counter
                counter += 1
            if len(tokens) >= current_counter + (remaining_tokens / 2):
                canEmplace = False

    last_offset = counter - 1

    token_cutoffs = {}
    token_cutoffs["last_pitch"] = last_pitch
    token_cutoffs["last_velocity"] = last_velocity
    token_cutoffs["last_duration"] = last_duration
    token_cutoffs["last_offset"] = last_offset

    return tokens, token_cutoffs
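A toy invocation sketch. The occurrence dicts and max_tokens below are hypothetical and chosen so that both while loops terminate: the function walks the duration/offset lists without a bounds check, so inputs with too few distinct values <= 1.0 (durations) or <= 2.0 (offsets) would raise IndexError.

pitches = {60: 5, 62: 3, 64: 2}
durations = {0.25: 4, 0.5: 3, 1.0: 2, 0.75: 1}
offsets = {0.125: 5, 1.25: 2, 1.75: 1}

tokens, cutoffs = getTokensA(pitches, durations, offsets, max_tokens=15)
# tokens maps each pitch/velocity/duration/offset value to an integer id;
# cutoffs == {'last_pitch': 3, 'last_velocity': 7, 'last_duration': 10, 'last_offset': 13}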
Example #37
0
    def validate(self,
                 protocol_name,
                 subset='development',
                 aggregate=False,
                 every=1,
                 start=0):

        # prepare paths
        validate_dir = self.VALIDATE_DIR.format(train_dir=self.train_dir_,
                                                protocol=protocol_name)
        validate_txt = self.VALIDATE_TXT.format(
            validate_dir=validate_dir,
            subset=subset,
            aggregate='aggregate.' if aggregate else '')
        validate_png = self.VALIDATE_PNG.format(
            validate_dir=validate_dir,
            subset=subset,
            aggregate='aggregate.' if aggregate else '')
        validate_eps = self.VALIDATE_EPS.format(
            validate_dir=validate_dir,
            subset=subset,
            aggregate='aggregate.' if aggregate else '')

        # create validation directory
        mkdir_p(validate_dir)

        # Build validation set
        if aggregate:
            X, n, y = self._validation_set_z(protocol_name, subset=subset)
        else:
            X, y = self._validation_set_y(protocol_name, subset=subset)

        # list of equal error rates, and epoch to process
        eers, epoch = SortedDict(), start

        desc_format = ('Best EER = {best_eer:.2f}% @ epoch #{best_epoch:d} ::'
                       ' EER = {eer:.2f}% @ epoch #{epoch:d} :')

        progress_bar = tqdm(unit='epoch')

        with open(validate_txt, mode='w') as fp:

            # watch and evaluate forever
            while True:

                # last completed epochs
                completed_epochs = self.get_epochs(self.train_dir_) - 1

                if completed_epochs < epoch:
                    time.sleep(60)
                    continue

                # if last completed epoch has already been processed
                # go back to first epoch that hasn't been processed yet
                process_epoch = epoch if completed_epochs in eers \
                                      else completed_epochs

                # do not validate this epoch if it has been done before...
                if process_epoch == epoch and epoch in eers:
                    epoch += every
                    progress_bar.update(every)
                    continue

                weights_h5 = LoggingCallback.WEIGHTS_H5.format(
                    log_dir=self.train_dir_, epoch=process_epoch)

                # this is needed for corner case when training is started from
                # an epoch > 0
                if not isfile(weights_h5):
                    time.sleep(60)
                    continue

                # sleep 5 seconds to let the checkpoint callback finish
                time.sleep(5)

                embedding = keras.models.load_model(
                    weights_h5, custom_objects=CUSTOM_OBJECTS, compile=False)

                if aggregate:

                    def embed(X):
                        func = K.function([
                            embedding.get_layer(name='input').input,
                            K.learning_phase()
                        ], [embedding.get_layer(name='internal').output])
                        return func([X, 0])[0]
                else:
                    embed = embedding.predict

                # embed all validation sequences
                fX = embed(X)

                if aggregate:
                    indices = np.hstack([[0], np.cumsum(n)])
                    fX = np.stack([
                        np.sum(np.sum(fX[i:j], axis=0), axis=0)
                        for i, j in pairwise(indices)
                    ])
                    fX = l2_normalize(fX)

                # compute pairwise distances
                y_pred = pdist(fX, metric=self.approach_.metric)
                # compute pairwise groundtruth
                y_true = pdist(y, metric='chebyshev') < 1
                # estimate equal error rate
                _, _, _, eer = det_curve(y_true, y_pred, distances=True)
                eers[process_epoch] = eer

                # save equal error rate to file
                fp.write(
                    self.VALIDATE_TXT_TEMPLATE.format(epoch=process_epoch,
                                                      eer=eer))
                fp.flush()

                # keep track of best epoch so far
                best_epoch = eers.iloc[np.argmin(eers.values())]
                best_eer = eers[best_epoch]

                progress_bar.set_description(
                    desc_format.format(epoch=process_epoch,
                                       eer=100 * eer,
                                       best_epoch=best_epoch,
                                       best_eer=100 * best_eer))

                # plot
                fig = plt.figure()
                plt.plot(eers.keys(), eers.values(), 'b')
                plt.plot([best_epoch], [best_eer], 'bo')
                plt.plot([eers.iloc[0], eers.iloc[-1]], [best_eer, best_eer],
                         'k--')
                plt.grid(True)
                plt.xlabel('epoch')
                plt.ylabel('EER on {subset}'.format(subset=subset))
                TITLE = '{best_eer:.5g} @ epoch #{best_epoch:d}'
                title = TITLE.format(best_eer=best_eer,
                                     best_epoch=best_epoch,
                                     subset=subset)
                plt.title(title)
                plt.tight_layout()
                plt.savefig(validate_png, dpi=75)
                plt.savefig(validate_eps)
                plt.close(fig)

                # go to next epoch
                if epoch == process_epoch:
                    epoch += every
                    progress_bar.update(every)
                else:
                    progress_bar.update(0)

        progress_bar.close()
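The best-epoch bookkeeping above relies on the SortedDict keeping epochs in key order. A minimal sketch of that pattern in isolation (assuming sortedcontainers and numpy, with made-up EER values):

import numpy as np
from sortedcontainers import SortedDict

eers = SortedDict({0: 0.31, 5: 0.18, 10: 0.22})
best_epoch = eers.keys()[np.argmin(list(eers.values()))]  # keys()[...] replaces the deprecated .iloc
best_eer = eers[best_epoch]
print(best_epoch, best_eer)   # 5 0.18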
Example #38
0
class SpaceBinaryTree:
    '''
    Attributes
    '''
    free_pages = 0
    sorted_dict = None
    '''
    Methods
    '''

    # Constructor
    def __init__(self, free_pages):
        #keep track of the number of pages free
        self.free_pages = free_pages
        # define the attributes
        #creating the list of indexes to keep track.
        free_pages_available = [x for x in range(free_pages)]
        # self.head = TreeNode(node_left=None, node_right=None, size=size, free_pages=free_pages)
        self.sorted_dict = SortedDict(
            {free_pages: TreeNode(free_pages=[free_pages_available])})

    def set_empty_space(self, num_of_slots, free_pages):
        if DEBUG:
            logger.debug("[space_binary_tree] Inside set empty_space")
            logger.debug(
                "[space_binary_tree] About to set the num_of_free_pages")
            logger.debug(
                "[space_binary_tree] number of slots to free: {}".format(
                    num_of_slots))
            logger.debug("[space_binary_tree] free pages: {}".format(
                self.free_pages))

        self.set_num_free_pages(ADD, num_of_slots)

        if DEBUG:
            logger.debug(
                "[space_binary_tree] Already set the num_of_free_pages")
            logger.debug(
                "[space_binary_tree] number of slots to free: {}".format(
                    num_of_slots))
            logger.debug("[space_binary_tree] free pages: {}".format(
                self.free_pages))

        if self.sorted_dict.get(num_of_slots):
            if DEBUG:
                logger.debug(
                    "[space_binary_tree] Number of slots exist in sorted_dict")
                # logger.debug("[space_binary_tree] free pages given: {}".format(free_pages))
                logger.debug("[space_binary_tree] sorted_dict: {}".format(
                    self.sorted_dict.get(num_of_slots).size))
                self.sorted_dict.get(num_of_slots).print_free_pages()

            self.sorted_dict.get(num_of_slots).set_free_pages(free_pages)

        # If the key does not exist, then we create a new node and add
        # the list to a new node
        else:
            if DEBUG:
                logger.debug("[space_binary_tree] New entry")
                logger.debug("[space_binary_tree] free pages given:")

            self.sorted_dict[num_of_slots] = TreeNode(size=num_of_slots,
                                                      free_pages=[free_pages])

    # else:
    # print("Index Error, the node has no list remaining")

    def set_num_free_pages(self, mode, value):
        if mode == ADD:
            self.free_pages += value
        elif mode == REMOVE:
            self.free_pages -= value
        else:
            raise ValueError("unknown mode: {}".format(mode))

    def get_total_free_pages(self):
        return self.free_pages

    def get_available_space(self, number_of_chunks):
        #number of spaces is available in sorted dict
        '''
        :param number_of_chunks:
        :return:
        '''
        '''
        The data structure can find the exact number of pages requested
        if it does find a spot that fits, it can allocate that space
        '''
        if DEBUG:
            print("[SpaceBinaryTree] Number of chunks needed = {}".format(
                number_of_chunks))
            print("[SpaceBinaryTree] doe sit exist? = {}".format(
                self.sorted_dict.get(number_of_chunks)))

        #print("[SpaceBinaryTree] resultado del if = {}".format(self.sorted_dict.get(number_of_chunks)))
        if self.sorted_dict.get(number_of_chunks):
            if DEBUG:
                print(
                    "[SpaceBinaryTree] inside if number of chunks matches available"
                )

            temp_node = self.sorted_dict.pop(number_of_chunks)
            ret_list = temp_node.get_free_pages()
            self.set_num_free_pages(REMOVE, number_of_chunks)

            if not temp_node.is_node_empty():
                self.sorted_dict[number_of_chunks] = temp_node

            # ret_list = self.sorted_dict.get(number_of_chunks).get_free_pages()

            if DEBUG:
                for i, index in enumerate(ret_list):
                    print("[SpaceBinaryTree] return list item: {}".format(
                        ret_list[i]))
                    if i >= PRINT_LIST_BREAK:
                        break

            return ret_list
        else:
            '''
            Either the requested chunk is larger or smaller than the chunks available.
            '''
            try:
                #print("[SpaceBinaryTree] chunks needed are not found in list, has to be built")
                # Get the largest space available from sorted dict
                largest_chuck = self.sorted_dict.keys()[-1]
                if DEBUG:
                    print(
                        "[SpaceBinaryTree] number of chunks needed is not available in the data structure"
                    )
                    print(
                        "[SpaceBinaryTree] we need to create a list of free pages"
                    )
                    print(
                        "[SpaceBinaryTree] Largest chunk available: {}".format(
                            largest_chuck))

                #print("[SpaceBinaryTree] largest chunk is = {}".format(largest_chuck))
                # Find if the number of chunks needed is lower than the largest available
                if number_of_chunks < largest_chuck:
                    if DEBUG:
                        print("[SpaceBinaryTree] number of chunks: {}".format(
                            number_of_chunks))
                        print(
                            "[SpaceBinaryTree] largest chunk available in data structure: {}"
                            .format(largest_chuck))
                        print(
                            "[SpaceBinaryTree] Largest chunk is bigger than number of chunks needed"
                        )
                    # if it is, we need to take one of those chunks and break it down
                    #first we get the whole node out of the sorted dict
                    if DEBUG:
                        print(
                            "[SpaceBinaryTree] Getting the largest chunk node into a temp location"
                        )
                    temp_node = self.sorted_dict.pop(largest_chuck)
                    self.set_num_free_pages(REMOVE, largest_chuck)
                    #print("[SpaceBinaryTree] Temp node popped out of list = {}".format(temp_node.get_size()))
                    #find if the node free pages list is not empty
                    #print("[SpaceBinaryTree] Temp node is empty = {}".format(temp_node.is_node_empty()))
                    if not temp_node.is_node_empty():
                        if DEBUG:
                            print("[SpaceBinaryTree] Node is not empty")
                        # pop the first list of pages available in the free_pages list.
                        # remember that free pages is a list if list of pages
                        # The structure is to be able to keep list of pages that are the same size under
                        # a single dictionary with key number of free pages
                        #print("[spaceBinaryTree] inside the not empty node")
                        # temp_node.print_free_pages()
                        temp_list = temp_node.get_free_pages()
                        if DEBUG:
                            if isinstance(temp_list, list):
                                for i, index in enumerate(temp_list):
                                    print(
                                        "[SpaceBinaryTree] return list from free_pages: {}"
                                        .format(temp_list[i]))
                                    if i >= PRINT_LIST_BREAK:
                                        break
                            else:
                                print(
                                    "[SpaceBinaryTree] Temp list is not a list!!!!!!!!"
                                )
                                print("[SpaceBinaryTree] Temp list is = {}".
                                      format(temp_list))
                        # Once we have the new list, we need to split it into the pages
                        # needed and the remaining ones.
                        ret_list = temp_list[:number_of_chunks]
                        if DEBUG:
                            for i, index in enumerate(ret_list):
                                print("[SpaceBinaryTree] return list item: {}".
                                      format(ret_list[i]))
                                if i >= PRINT_LIST_BREAK:
                                    break
                        #keep track of the total number of pages available
                        # self.set_num_free_pages(REMOVE, number_of_chunks)
                        if DEBUG:
                            print(
                                "[SpaceBinaryTree] removing the number of spaces from the number of available spaces"
                            )
                            print(
                                "[SpaceBinaryTree] calling set_free_pages with {}: {}"
                                .format(REMOVE, number_of_chunks))
                        #print("Return list found = {}".format(len(ret_list)))
                        # Then we store the remaining list in an existing key
                        rem_list = temp_list[number_of_chunks:]
                        if DEBUG:
                            for i, index in enumerate(rem_list):
                                print(
                                    "[SpaceBinaryTree] remaining list item: {}"
                                    .format(rem_list[i]))
                                if i >= PRINT_LIST_BREAK:
                                    break
                        #print("Remaining list left = {}".format(len(rem_list)))
                        self.set_empty_space(len(rem_list), rem_list)
                        if DEBUG:
                            for i, index in enumerate(rem_list):
                                print(
                                    "[SpaceBinaryTree] remaining list set space: {}"
                                    .format(rem_list[i]))
                                if i >= PRINT_LIST_BREAK:
                                    break
                            print(
                                "[SpaceBinaryTree] setting empty space with length {}"
                                .format(len(rem_list)))
                        return ret_list
                    else:
                        err_message = "You got an empty node"
                        raise Exception(err_message)
                else:
                    if DEBUG:
                        print("[SpaceBinaryTree] number of chunks: {}".format(
                            number_of_chunks))
                        print(
                            "[SpaceBinaryTree] largest chunk available in data structure: {}"
                            .format(largest_chuck))
                        print(
                            "[SpaceBinaryTree] Largest chunk is smaller than number of chunks needed"
                        )
                    # Get the smallest space available from sorted dict
                    smallest_chuck = self.sorted_dict.keys()[0]
                    # Take smallest pieces and get as many until enough chunks
                    # meet the chunks needed
                    temp_node = self.sorted_dict.pop(smallest_chuck)
                    #number of chunks is larger than largest spot available.
                    enough_pages = False
                    ret_list = []
                    free_pages_needed = number_of_chunks
                    # We start from the smallest page runs and keep adding pages
                    # until there are enough to satisfy the requested chunks
                    while not enough_pages:
                        # Cycle through the page runs in this node and try to fill the gap
                        for lp in temp_node.get_all_free_pages():
                            if free_pages_needed > len(lp):
                                # consume the whole run and keep looking for more pages
                                ret_list = ret_list + lp
                                free_pages_needed = free_pages_needed - len(lp)
                            else:
                                # this run covers the rest of the request; split it
                                ret_list = ret_list + lp[:free_pages_needed]
                                leftover = lp[free_pages_needed:]
                                # if there are any pages left after the split, re-file them
                                if leftover:
                                    self.set_empty_space(len(leftover), leftover)
                                enough_pages = True
                                break

                        if enough_pages:
                            self.set_num_free_pages(REMOVE, number_of_chunks)
                            return ret_list

                        # this node is exhausted; move on to the next smallest node
                        if not self.sorted_dict:
                            raise Exception(
                                "Not enough free pages to satisfy the request")
                        smallest_chuck = self.sorted_dict.keys()[0]
                        temp_node = self.sorted_dict.pop(smallest_chuck)

            except Exception as e:
                print(str(e))
                raise
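
Note: the lookups in this example lean on sortedcontainers exposing positional access on the key view (keys()[-1] for the largest free size, keys()[0] for the smallest). A tiny standalone illustration of that primitive, not part of the original example:

from sortedcontainers import SortedDict

free_space = SortedDict({4: ['run-a'], 16: ['run-b'], 8: ['run-c']})
print(free_space.keys()[-1])  # 16 -> largest chunk size currently available
print(free_space.keys()[0])   # 4  -> smallest chunk size currently available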
Пример #39
0
    def generate_graphs(self, show=False):
        filename = "task_arrival_{0}.png".format(self.workload_name)
        if os.path.isfile(os.path.join(self.folder, filename)):
            return filename

        fig = plt.figure(figsize=(9, 7))
        granularity_order = ["Second", "Minute", "Hour", "Day"]

        granularity_lambdas = {
            "Second": 1000,
            "Minute": 60 * 1000,
            "Hour": 60 * 60 * 1000,
            "Day": 60 * 60 * 24 * 1000,
        }

        plot_count = 0

        for granularity in granularity_order:
            task_arrivals = SortedDict()
            df = self.df.withColumn(
                'ts_submit',
                F.col('ts_submit') / granularity_lambdas[granularity])
            df = df.withColumn('ts_submit',
                               F.col('ts_submit').cast(T.LongType()))
            submit_times = df.groupBy("ts_submit").count().toPandas()

            for task in submit_times.itertuples():
                submit_time = int(task.ts_submit)

                if submit_time not in task_arrivals:
                    task_arrivals[submit_time] = 0

                task_arrivals[submit_time] += task.count

            ax = plt.subplot2grid(
                (2, 2), (int(math.floor(plot_count / 2)), (plot_count % 2)))
            if max(task_arrivals.keys()) >= 1:
                ax.plot(task_arrivals.keys(),
                        task_arrivals.values(),
                        color="black",
                        linewidth=1.0)
                ax.grid(True)
            else:
                ax.text(0.5,
                        0.5,
                        'Not available;\nTrace too small.',
                        horizontalalignment='center',
                        verticalalignment='center',
                        transform=ax.transAxes,
                        fontsize=16)
                ax.grid(False)

            # Rotates and right aligns the x labels, and moves the bottom of the
            # axes up to make room for them
            # fig.autofmt_xdate()

            ax.set_xlim(0)
            ax.set_ylim(0)

            ax.locator_params(nbins=3, axis='y')

            ax.margins(0.05)
            ax.tick_params(axis='both', which='major', labelsize=16)
            ax.tick_params(axis='both', which='minor', labelsize=14)

            ax.get_xaxis().get_offset_text().set_visible(False)
            formatter = ScalarFormatter(useMathText=True)
            formatter.set_powerlimits((-4, 5))
            ax.get_xaxis().set_major_formatter(formatter)
            fig.tight_layout(
            )  # Need to set this to be able to get the offset... for whatever reason
            offset_text = ax.get_xaxis().get_major_formatter().get_offset()

            ax.set_xlabel('Time{0} [{1}]'.format(
                f' {offset_text}' if len(offset_text) else "",
                granularity.lower()),
                          fontsize=18)
            ax.set_ylabel('Number of Tasks', fontsize=18)

            plot_count += 1

        fig.tight_layout()

        fig.savefig(os.path.join(self.folder, filename), dpi=600, format='png')
        if show:
            fig.show()

        return filename
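
The Spark pipeline above effectively buckets millisecond submit times by a granularity divisor and counts tasks per bucket before plotting. A rough pure-Python sketch of just that counting step (count_task_arrivals is a hypothetical helper name; submit times are assumed to be in milliseconds):

from sortedcontainers import SortedDict

def count_task_arrivals(ts_submit_ms, granularity_ms):
    # Bucket millisecond submit times by granularity and count tasks per bucket
    task_arrivals = SortedDict()
    for ts in ts_submit_ms:
        bucket = int(ts // granularity_ms)
        task_arrivals[bucket] = task_arrivals.get(bucket, 0) + 1
    return task_arrivals

# three tasks submitted within the first two seconds, bucketed per second
print(count_task_arrivals([100, 900, 1500], 1000))  # SortedDict({0: 2, 1: 1})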
Пример #40
0
class SpaceBinaryTree:
    '''
    Attributes
    '''

    head = None
    sorted_dict = None
    '''
    Methods
    '''

    # Constructor
    def __init__(self, size, free_pages):
        # define the attributes
        free_pages = [x for x in range(free_pages)]
        self.head = TreeNode(node_left=None,
                             node_right=None,
                             size=size,
                             free_pages=free_pages)
        self.sorted_dict = SortedDict({
            size:
            TreeNode(node_left=None,
                     node_right=None,
                     size=size,
                     free_pages=[free_pages])
        })

    def set_empty_space(self, num_of_slots, free_pages):
        if self.sorted_dict.get(num_of_slots):
            self.sorted_dict.get(num_of_slots).set_free_pages(free_pages)
        # If the key does not exist, then we create a new node and add
        # the list to a new node
        else:
            self.sorted_dict[num_of_slots] = TreeNode(node_left=None,
                                                      node_right=None,
                                                      size=num_of_slots,
                                                      free_pages=[free_pages])

    # else:
    # print("Index Error, the node has no list remaining")

    def get_available_space(self, number_of_chunks):
        '''
        Find and allocate free pages for the requested number of chunks.

        If the data structure holds a run with the exact number of pages
        requested it allocates that run directly; otherwise it splits a
        larger run or combines smaller ones.

        :param number_of_chunks: number of pages requested
        :return: list of allocated page numbers
        '''
        print("[SpaceBinaryTree] Number of chunks needed = {}".format(
            number_of_chunks))
        #print("[SpaceBinaryTree] resultado del if = {}".format(self.sorted_dict.get(number_of_chunks)))
        if self.sorted_dict.get(number_of_chunks):
            print(
                "[SpaceBinaryTree] inside if number of chunks matches available"
            )
            ret_list = self.sorted_dict.get(number_of_chunks).free_pages.pop(0)
            return ret_list
        else:
            '''
            Either the requested chunk count is larger or smaller than the chunk sizes available.
            '''
            try:
                #print("[SpaceBinaryTree] chunks needed are not found in list, has to be built")
                # Get the largest space available from sorted dict
                largest_chuck = self.sorted_dict.keys()[-1]
                #print("[SpaceBinaryTree] largest chunk is = {}".format(largest_chuck))
                # Find if the number of chunks needed is lower than the largest available
                if number_of_chunks < largest_chuck:
                    # if it is, we need to take one of those chunks and break it down
                    #first we get the whole node out of the sorted dict
                    temp_node = self.sorted_dict.pop(largest_chuck)
                    #print("[SpaceBinaryTree] Temp node popped out of list = {}".format(temp_node.get_size()))
                    #find if the node free pages list is not empty
                    #print("[SpaceBinaryTree] Temp node is empty = {}".format(temp_node.is_node_empty()))
                    if not temp_node.is_node_empty():
                        # pop the first list of pages available in the free_pages list.
                        # remember that free_pages is a list of lists of pages.
                        # The structure keeps page runs of the same size under
                        # a single dictionary key: the number of free pages
                        #print("[spaceBinaryTree] inside the not empty node")
                        # temp_node.print_free_pages()
                        temp_list = temp_node.get_free_pages()
                        #print("[SpaceBinaryTree] Temp list is = {}".format(len(temp_list)))
                        # Once we have the new list, we need to split it into the pages
                        # needed and the remaining ones.
                        ret_list = temp_list[:number_of_chunks]
                        #print("Return list found = {}".format(len(ret_list)))
                        # Then we store the remaining list in an existing key
                        rem_list = temp_list[number_of_chunks:]
                        #print("Remaining list left = {}".format(len(rem_list)))
                        self.set_empty_space(len(rem_list), rem_list)
                        return ret_list
                    else:
                        err_message = "You got an empty node"
                        raise Exception(err_message)
                else:
                    # Get the smallest space available from sorted dict
                    smallest_chuck = self.sorted_dict.keys()[0]
                    # Take smallest pieces and get as many until enough chunks
                    # meet the chunks needed
                    temp_node = self.sorted_dict.pop(smallest_chuck)
                    #number of chunks is larger than largest spot available.
                    enough_pages = False
                    ret_list = []
                    free_pages_needed = number_of_chunks
                    # We start from the smallest page runs and keep adding pages
                    # until there are enough to satisfy the requested chunks
                    while not enough_pages:
                        # Cycle through the page runs in this node and try to fill the gap
                        for lp in temp_node.get_all_free_pages():
                            if free_pages_needed > len(lp):
                                # consume the whole run and keep looking for more pages
                                ret_list = ret_list + lp
                                free_pages_needed = free_pages_needed - len(lp)
                            else:
                                # this run covers the rest of the request; split it
                                ret_list = ret_list + lp[:free_pages_needed]
                                leftover = lp[free_pages_needed:]
                                # if there are any pages left after the split, re-file them
                                if leftover:
                                    self.set_empty_space(len(leftover), leftover)
                                enough_pages = True
                                break

                        if enough_pages:
                            return ret_list

                        # this node is exhausted; move on to the next smallest node
                        if not self.sorted_dict:
                            raise Exception(
                                "Not enough free pages to satisfy the request")
                        smallest_chuck = self.sorted_dict.keys()[0]
                        temp_node = self.sorted_dict.pop(smallest_chuck)

            except Exception as e:
                print(str(e))
                raise
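
Stripped of the TreeNode bookkeeping, the allocator above keeps a SortedDict that maps run length to lists of free-page runs: exact-size requests are served directly, otherwise the largest run is split and the remainder is re-filed under its new length. A simplified, self-contained sketch of that pattern (illustrative only; it omits the branch that combines several smaller runs):

from sortedcontainers import SortedDict

class SimpleFreeSpaceMap:
    """Toy free-space map: run length -> list of free page runs of that length."""

    def __init__(self, total_pages):
        self.by_size = SortedDict({total_pages: [list(range(total_pages))]})

    def add_run(self, run):
        # re-file a run of free pages under its length
        if run:
            self.by_size.setdefault(len(run), []).append(run)

    def allocate(self, n):
        if n in self.by_size:  # exact fit
            runs = self.by_size[n]
            run = runs.pop(0)
            if not runs:
                del self.by_size[n]
            return run
        largest = self.by_size.keys()[-1]  # otherwise split the largest run
        if largest < n:
            raise MemoryError("not enough contiguous pages")
        runs = self.by_size.pop(largest)
        run = runs.pop(0)
        if runs:
            self.by_size[largest] = runs
        self.add_run(run[n:])  # keep the remainder
        return run[:n]

fsm = SimpleFreeSpaceMap(8)
print(fsm.allocate(3))  # [0, 1, 2]
print(fsm.allocate(5))  # [3, 4, 5, 6, 7]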
Пример #41
0
class CacheStore(object):
    class CacheItem(object):
        __slots__ = ('valid', 'data')

        def __init__(self):
            self.valid = Event()
            self.data = None

    def __init__(self, key=None):
        self.lock = RLock()
        self.store = SortedDict(key)

    def __getitem__(self, item):
        return self.get(item)

    def put(self, key, data):
        with self.lock:
            try:
                item = self.store[key]
                item.data = data
                item.valid.set()
                return False
            except KeyError:
                item = self.CacheItem()
                item.data = data
                item.valid.set()
                self.store[key] = item
                return True

    def update(self, **kwargs):
        with self.lock:
            items = {}
            created = []
            updated = []
            for k, v in kwargs.items():
                items[k] = self.CacheItem()
                items[k].data = v
                items[k].valid.set()
                if k in self.store:
                    updated.append(k)
                else:
                    created.append(k)

            self.store.update(**items)
            return created, updated

    def update_one(self, key, **kwargs):
        with self.lock:
            item = self.get(key)
            if not item:
                return False

            for k, v in kwargs.items():
                set(item, k, v)

            self.put(key, item)
            return True

    def update_many(self, key, predicate, **kwargs):
        with self.lock:
            updated = []
            for k, v in self.itervalid():
                if predicate(v):
                    if self.update_one(k, **kwargs):
                        updated.append(k)

            return updated

    def get(self, key, default=None, timeout=None):
        item = self.store.get(key)
        if item:
            item.valid.wait(timeout)
            return item.data

        return default

    def remove(self, key):
        with self.lock:
            try:
                del self.store[key]
                return True
            except KeyError:
                return False

    def remove_many(self, keys):
        with self.lock:
            removed = []
            for key in keys:
                try:
                    del self.store[key]
                    removed.append(key)
                except KeyError:
                    pass

            return removed

    def clear(self):
        with self.lock:
            items = list(self.store.keys())
            self.store.clear()
            return items

    def exists(self, key):
        return key in self.store

    def rename(self, oldkey, newkey):
        with self.lock:
            obj = self.get(oldkey)
            obj['id'] = newkey
            self.put(newkey, obj)
            self.remove(oldkey)

    def is_valid(self, key):
        item = self.store.get(key)
        if item:
            return item.valid.is_set()

        return False

    def invalidate(self, key):
        with self.lock:
            item = self.store.get(key)
            if item:
                item.valid.clear()

    def itervalid(self):
        for key, value in list(self.store.items()):
            if value.valid.is_set():
                yield (key, value.data)

    def validvalues(self):
        for value in list(self.store.values()):
            if value.valid.is_set():
                yield value.data

    def remove_predicate(self, predicate):
        result = []
        for k, v in self.itervalid():
            if predicate(v):
                self.remove(k)
                result.append(k)

        return result

    def query(self, *filter, **params):
        return query(list(self.validvalues()), *filter, **params)
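
A minimal usage sketch for the cache above. It assumes the module-level names the class relies on (SortedDict plus Event and RLock, taken here from threading; the original module may use gevent equivalents) and does not exercise the external query() helper:

from threading import Event, RLock
from sortedcontainers import SortedDict

cache = CacheStore()
print(cache.put('vol1', {'id': 'vol1', 'size': 10}))  # True  -> entry created
print(cache.put('vol1', {'id': 'vol1', 'size': 20}))  # False -> entry updated
print(cache.get('vol1'))                              # {'id': 'vol1', 'size': 20}
cache.invalidate('vol1')
print(cache.is_valid('vol1'))                         # False
print(cache.remove('vol1'))                           # True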
Пример #42
0
def read_swans(fileglob,
               ndays=None,
               int_freq=True,
               int_dir=False,
               dirorder=True,
               ntimes=None):
    """Read multiple SWAN ASCII files into single Dataset.

    Args:
        - fileglob (str, list): glob pattern specifying files to read.
        - ndays (float): number of days to keep from each file, choose None to
          keep entire period.
        - int_freq (ndarray, bool): frequency array for interpolating onto:
            - ndarray: 1d array specifying frequencies to interpolate onto.
            - True: logarithm array is constructed such that fmin=0.0418 Hz,
              fmax=0.71856 Hz, df=0.1f.
            - False: No interpolation performed in frequency space.
        - int_dir (ndarray, bool): direction array for interpolating onto:
            - ndarray: 1d array specifying directions to interpolate onto.
            - True: circular array is constructed such that dd=10 degrees.
            - False: No interpolation performed in direction space.
        - dirorder (bool): if True ensures directions are sorted.
        - ntimes (int): use it to read only specific number of times, useful
          for checking headers only.

    Returns:
        - dset (SpecDataset): spectra dataset object read from file with
          different sites and cycles concatenated along the 'site' and 'time'
          dimensions.

    Note:
        - If multiple cycles are provided, 'time' coordinate is replaced by
          'cycletime' multi-index coordinate.
        - If more than one cycle is prescribed from fileglob, each cycle must
          have same number of sites.
        - Either all or none of the spectra in fileglob must have tabfile
          associated to provide wind/depth data.
        - Concatenation is done with numpy arrays for efficiency.

    """
    swans = (sorted(fileglob) if isinstance(fileglob, list) else sorted(
        glob.glob(fileglob)))
    assert swans, "No SWAN file identified with fileglob %s" % (fileglob)

    # Default spectral basis for interpolating
    if int_freq is True:
        int_freq = [0.04118 * 1.1**n for n in range(31)]
    elif int_freq is False:
        int_freq = None
    if int_dir is True:
        int_dir = np.arange(0, 360, 10)
    elif int_dir is False:
        int_dir = None

    cycles = list()
    dsets = SortedDict()
    tabs = SortedDict()
    all_times = list()
    all_sites = SortedDict()
    all_lons = SortedDict()
    all_lats = SortedDict()
    deps = SortedDict()
    wspds = SortedDict()
    wdirs = SortedDict()

    for filename in swans:
        swanfile = SwanSpecFile(filename, dirorder=dirorder)
        times = swanfile.times
        lons = list(swanfile.x)
        lats = list(swanfile.y)
        sites = ([os.path.splitext(os.path.basename(filename))[0]]
                 if len(lons) == 1 else np.arange(len(lons)) + 1)
        freqs = swanfile.freqs
        dirs = swanfile.dirs

        if ntimes is None:
            spec_list = [s for s in swanfile.readall()]
        else:
            spec_list = [swanfile.read() for itime in range(ntimes)]

        # Read tab files for winds / depth
        if swanfile.is_tab:
            try:
                tab = read_tab(
                    swanfile.tabfile).rename(columns={"dep": attrs.DEPNAME})
                if len(swanfile.times) == tab.index.size:
                    if "X-wsp" in tab and "Y-wsp" in tab:
                        tab[attrs.WSPDNAME], tab[
                            attrs.WDIRNAME] = uv_to_spddir(tab["X-wsp"],
                                                           tab["Y-wsp"],
                                                           coming_from=True)
                else:
                    warnings.warn(
                        f"Times in {swanfile.filename} and {swanfile.tabfile} "
                        f"not consistent, not appending winds and depth")
                    tab = pd.DataFrame()
                tab = tab[list(
                    set(tab.columns).intersection(
                        (attrs.DEPNAME, attrs.WSPDNAME, attrs.WDIRNAME)))]
            except Exception as exc:
                warnings.warn(
                    f"Cannot parse depth and winds from {swanfile.tabfile}:\n{exc}"
                )
        else:
            tab = pd.DataFrame()

        # Shrinking times
        if ndays is not None:
            tend = times[0] + datetime.timedelta(days=ndays)
            if tend > times[-1]:
                raise OSError("Times in %s does not extend for %0.2f days" %
                              (filename, ndays))
            iend = times.index(min(times, key=lambda d: abs(d - tend)))
            times = times[0:iend + 1]
            spec_list = spec_list[0:iend + 1]
            tab = tab.loc[times[0]:tend] if tab is not None else tab

        spec_list = flatten_list(spec_list, [])

        # Interpolate spectra
        if int_freq is not None or int_dir is not None:
            spec_list = [
                interp_spec(spec, freqs, dirs, int_freq, int_dir)
                for spec in spec_list
            ]
            freqs = int_freq if int_freq is not None else freqs
            dirs = int_dir if int_dir is not None else dirs

        # Appending
        try:
            arr = np.array(spec_list).reshape(len(times), len(sites),
                                              len(freqs), len(dirs))
            cycle = times[0]
            if cycle not in dsets:
                dsets[cycle] = [arr]
                tabs[cycle] = [tab]
                all_sites[cycle] = sites
                all_lons[cycle] = lons
                all_lats[cycle] = lats
                all_times.append(times)
                nsites = 1
            else:
                dsets[cycle].append(arr)
                tabs[cycle].append(tab)
                all_sites[cycle].extend(sites)
                all_lons[cycle].extend(lons)
                all_lats[cycle].extend(lats)
                nsites += 1
        except Exception:
            if len(spec_list) != arr.shape[0]:
                raise OSError(
                    "Time length in %s (%i) does not match previous files (%i), "
                    "cannot concatenate" %
                    (filename, len(spec_list), arr.shape[0])
                )
            else:
                raise
        swanfile.close()

    cycles = dsets.keys()

    # Ensuring sites are consistent across cycles
    sites = all_sites[cycle]
    lons = all_lons[cycle]
    lats = all_lats[cycle]
    for site, lon, lat in zip(all_sites.values(), all_lons.values(),
                              all_lats.values()):
        if ((list(site) != list(sites)) or (list(lon) != list(lons))
                or (list(lat) != list(lats))):
            raise OSError(
                "Inconsistent sites across cycles in glob pattern provided")

    # Ensuring consistent tabs
    cols = set([
        frozenset(tabs[cycle][n].columns) for cycle in cycles
        for n in range(len(tabs[cycle]))
    ])
    if len(cols) > 1:
        raise OSError(
            "Inconsistent tab files, ensure either all or none of the spectra have "
            "associated tabfiles and columns are consistent")

    # Concat sites
    for cycle in cycles:
        dsets[cycle] = np.concatenate(dsets[cycle], axis=1)
        deps[cycle] = (np.vstack([
            tab[attrs.DEPNAME].values for tab in tabs[cycle]
        ]).T if attrs.DEPNAME in tabs[cycle][0] else None)
        wspds[cycle] = (np.vstack([
            tab[attrs.WSPDNAME].values for tab in tabs[cycle]
        ]).T if attrs.WSPDNAME in tabs[cycle][0] else None)
        wdirs[cycle] = (np.vstack([
            tab[attrs.WDIRNAME].values for tab in tabs[cycle]
        ]).T if attrs.WDIRNAME in tabs[cycle][0] else None)

    time_sizes = [dsets[cycle].shape[0] for cycle in cycles]

    # Concat cycles
    if len(dsets) > 1:
        dsets = np.concatenate(dsets.values(), axis=0)
        deps = (np.concatenate(deps.values(), axis=0)
                if attrs.DEPNAME in tabs[cycle][0] else None)
        wspds = (np.concatenate(wspds.values(), axis=0)
                 if attrs.WSPDNAME in tabs[cycle][0] else None)
        wdirs = (np.concatenate(wdirs.values(), axis=0)
                 if attrs.WDIRNAME in tabs[cycle][0] else None)
    else:
        dsets = dsets[cycle]
        deps = deps[cycle] if attrs.DEPNAME in tabs[cycle][0] else None
        wspds = wspds[cycle] if attrs.WSPDNAME in tabs[cycle][0] else None
        wdirs = wdirs[cycle] if attrs.WDIRNAME in tabs[cycle][0] else None

    # Creating dataset
    times = flatten_list(all_times, [])
    dsets = xr.DataArray(
        data=dsets,
        coords=OrderedDict((
            (attrs.TIMENAME, times),
            (attrs.SITENAME, sites),
            (attrs.FREQNAME, freqs),
            (attrs.DIRNAME, dirs),
        )),
        dims=(attrs.TIMENAME, attrs.SITENAME, attrs.FREQNAME, attrs.DIRNAME),
        name=attrs.SPECNAME,
    ).to_dataset()

    dsets[attrs.LATNAME] = xr.DataArray(data=lats,
                                        coords={attrs.SITENAME: sites},
                                        dims=[attrs.SITENAME])
    dsets[attrs.LONNAME] = xr.DataArray(data=lons,
                                        coords={attrs.SITENAME: sites},
                                        dims=[attrs.SITENAME])

    if wspds is not None:
        dsets[attrs.WSPDNAME] = xr.DataArray(
            data=wspds,
            dims=[attrs.TIMENAME, attrs.SITENAME],
            coords=OrderedDict(
                ((attrs.TIMENAME, times), (attrs.SITENAME, sites))),
        )
        dsets[attrs.WDIRNAME] = xr.DataArray(
            data=wdirs,
            dims=[attrs.TIMENAME, attrs.SITENAME],
            coords=OrderedDict(
                ((attrs.TIMENAME, times), (attrs.SITENAME, sites))),
        )
    if deps is not None:
        dsets[attrs.DEPNAME] = xr.DataArray(
            data=deps,
            dims=[attrs.TIMENAME, attrs.SITENAME],
            coords=OrderedDict(
                ((attrs.TIMENAME, times), (attrs.SITENAME, sites))),
        )

    # Setting multi-index
    if len(cycles) > 1:
        dsets = dsets.rename({attrs.TIMENAME: "cycletime"})
        cycletime = zip(
            [
                item
                for sublist in [[c] * t for c, t in zip(cycles, time_sizes)]
                for item in sublist
            ],
            dsets.cycletime.values,
        )
        dsets["cycletime"] = pd.MultiIndex.from_tuples(
            cycletime, names=[attrs.CYCLENAME, attrs.TIMENAME])
        dsets["cycletime"].attrs = attrs.ATTRS[attrs.TIMENAME]

    set_spec_attributes(dsets)
    if "dir" in dsets and len(dsets.dir) > 1:
        dsets[attrs.SPECNAME].attrs.update({
            "_units": "m^{2}.s.degree^{-1}",
            "_variable_name": "VaDens"
        })
    else:
        dsets[attrs.SPECNAME].attrs.update({
            "units": "m^{2}.s",
            "_units": "m^{2}.s",
            "_variable_name": "VaDens"
        })

    return dsets
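
The multi-cycle branch near the end pairs each cycle start, repeated over the number of time steps read for that cycle, with the flattened time axis to build the cycletime MultiIndex. A small standalone sketch of just that pairing (hypothetical values; plain "cycle"/"time" names stand in for attrs.CYCLENAME and attrs.TIMENAME):

import pandas as pd

cycles = [pd.Timestamp("2023-01-01"), pd.Timestamp("2023-01-02")]
time_sizes = [2, 3]  # time steps read for each cycle
times = [pd.Timestamp("2023-01-01") + pd.Timedelta(hours=i) for i in range(5)]

repeated_cycles = [c for c, n in zip(cycles, time_sizes) for _ in range(n)]
cycletime = list(zip(repeated_cycles, times))
index = pd.MultiIndex.from_tuples(cycletime, names=["cycle", "time"])
print(index)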
Пример #43
0
class RangeCounter(object):
    def __init__(self, k):
        self.k = k
        from sortedcontainers import SortedDict
        self.ranges = SortedDict()

    def process(self, transaction_info):
        for get_range in transaction_info.get_ranges:
            self._insert_range(get_range.key_range.start_key,
                               get_range.key_range.end_key)

    def _insert_range(self, start_key, end_key):
        keys = self.ranges.keys()
        if len(keys) == 0:
            self.ranges[start_key] = end_key, 1
            return

        start_pos = bisect_left(keys, start_key)
        end_pos = bisect_left(keys, end_key)
        #print("start_pos=%d, end_pos=%d" % (start_pos, end_pos))

        possible_intersection_keys = keys[max(0, start_pos -
                                              1):min(len(keys), end_pos + 1)]

        start_range_left = start_key

        for key in possible_intersection_keys:
            cur_end_key, cur_count = self.ranges[key]
            #logger.debug("key=%s, cur_end_key=%s, cur_count=%d, start_range_left=%s" % (key, cur_end_key, cur_count, start_range_left))
            if start_range_left < key:
                if end_key <= key:
                    self.ranges[start_range_left] = end_key, 1
                    return
                self.ranges[start_range_left] = key, 1
                start_range_left = key
            assert start_range_left >= key
            if start_range_left >= cur_end_key:
                continue

            # [key, start_range_left) = cur_count
            # if key == start_range_left this will get overwritten below
            self.ranges[key] = start_range_left, cur_count

            if end_key <= cur_end_key:
                # [start_range_left, end_key) = cur_count+1
                # [end_key, cur_end_key) = cur_count
                self.ranges[start_range_left] = end_key, cur_count + 1
                if end_key != cur_end_key:
                    self.ranges[end_key] = cur_end_key, cur_count
                start_range_left = end_key
                break
            else:
                # [start_range_left, cur_end_key) = cur_count+1
                self.ranges[start_range_left] = cur_end_key, cur_count + 1
                start_range_left = cur_end_key
            assert start_range_left <= end_key

        # there may be some range left
        if start_range_left < end_key:
            self.ranges[start_range_left] = end_key, 1

    def get_count_for_key(self, key):
        if key in self.ranges:
            return self.ranges[key][1]

        keys = self.ranges.keys()
        index = bisect_left(keys, key)
        if index == 0:
            return 0

        index_key = keys[index - 1]
        if index_key <= key < self.ranges[index_key][0]:
            return self.ranges[index_key][1]
        return 0

    def get_range_boundaries(self, shard_finder=None):
        total = sum([count for _, (_, count) in self.ranges.items()])
        range_size = total // self.k
        output_range_counts = []

        def add_boundary(start, end, count):
            if shard_finder:
                shard_count = shard_finder.get_shard_count(start, end)
                if shard_count == 1:
                    addresses = shard_finder.get_addresses_for_key(start)
                else:
                    addresses = None
                output_range_counts.append(
                    (start, end, count, shard_count, addresses))
            else:
                output_range_counts.append((start, end, count, None, None))

        this_range_start_key = None
        count_this_range = 0
        for (start_key, (end_key, count)) in self.ranges.items():
            if not this_range_start_key:
                this_range_start_key = start_key
            count_this_range += count
            if count_this_range >= range_size:
                add_boundary(this_range_start_key, end_key, count_this_range)
                count_this_range = 0
                this_range_start_key = None
        if count_this_range > 0:
            add_boundary(this_range_start_key, end_key, count_this_range)

        return output_range_counts
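
A minimal usage sketch for the range counter above. It assumes "from bisect import bisect_left" at module scope (required by _insert_range) and calls that private method directly instead of going through process():

from bisect import bisect_left  # required by RangeCounter._insert_range

rc = RangeCounter(k=2)
rc._insert_range(b"a", b"m")
rc._insert_range(b"f", b"z")

print(rc.get_count_for_key(b"g"))  # 2 -> covered by both inserted ranges
print(rc.get_count_for_key(b"c"))  # 1 -> covered by the first range only
print(rc.get_range_boundaries())   # roughly k buckets of (start, end, count, None, None)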
Пример #44
0
class AttributeSet:
    """The AttributeSet class that represents an attribute set."""
    def __init__(self, attributes: Optional[Iterable[Attribute]] = None):
        """Initialize the AttributeSet object with the attributes.

        Args:
            attributes: The attributes that compose the attribute set if set.

        Raises:
            DuplicateAttributeId: Two attributes share the same id.
        """
        # Maintain a sorted dictionary linking the attributes id to the
        # attribute objects
        self._id_to_attr = SortedDict()
        if attributes:
            for attribute in attributes:
                self.add(attribute)

    def __iter__(self) -> Iterator:
        """Give the iterator for the AttributeSet to get the attributes.

        Returns:
            An iterator that iterates over the Attribute objects that compose
            the attribute set.
        """
        return iter(self._id_to_attr.values())

    def __repr__(self) -> str:
        """Provide a string representation of the attribute set.

        Returns:
            A string representation of the attribute set.
        """
        attribute_list = ', '.join(
            str(attr) for attr in self._id_to_attr.values())
        return f'{self.__class__.__name__}([{attribute_list}])'

    @property
    def attribute_names(self) -> List[str]:
        """Give the names of the attributes of this attribute set (read only).

        The attribute names are sorted in function of the attribute ids.

        Returns:
            The name of the attributes of this attribute set as a list of str.
        """
        return list(attribute.name for attribute in self._id_to_attr.values())

    @property
    def attribute_ids(self) -> List[int]:
        """Give the ids of the attributes of this attribute set (read only).

        Returns:
            The ids of the attributes of this set as a sorted list of integers.
        """
        return list(self._id_to_attr.keys())

    def add(self, attribute: Attribute):
        """Add an attribute to this attribute set if it is not already present.

        Args:
            attribute: The attribute to add.

        Raises:
            DuplicateAttributeId: An attribute with the same id as the
                                  attribute that is added already exists.
        """
        if attribute.attribute_id in self._id_to_attr:
            raise DuplicateAttributeId('An attribute with the same id as '
                                       f'{attribute} already exists.')
        self._id_to_attr[attribute.attribute_id] = attribute

    def remove(self, attribute: Attribute):
        """Remove an attribute from this attribute set.

        Args:
            attribute: The attribute to remove.

        Raises:
            KeyError: The attribute is not present in this attribute set.
        """
        if attribute.attribute_id not in self._id_to_attr:
            raise KeyError(f'{attribute} is not among the attributes.')
        del self._id_to_attr[attribute.attribute_id]

    def __hash__(self) -> int:
        """Give the hash of an attribute set: the hash of its attributes.

        Returns:
            The hash of an attribute set as the hash of its frozen attributes.
        """
        return hash(frozenset(self.attribute_ids))

    def __eq__(self, other_attr_set: 'AttributeSet') -> bool:
        """Compare two attribute sets, equal if the attributes correspond.

        Args:
            other_attr_set: The other attribute set to which the attribute set
                            is compared with.

        Returns:
            The two attribute sets are equal: they share the same attributes.
        """
        return (isinstance(other_attr_set, self.__class__)
                and hash(self) == hash(other_attr_set))

    def __contains__(self, attribute: Attribute) -> bool:
        """Check if the attribute is in the attribute set.

        Args:
            attribute: The attribute that is checked whether it is in this set.

        Returns:
            The attribute is in the attribute set.
        """
        return attribute.attribute_id in self._id_to_attr

    def __len__(self) -> int:
        """Give the size of this attribute set as the number of attributes.

        Returns:
            The number of attributes in this attribute set.
        """
        return len(self._id_to_attr)

    def issuperset(self, other_attribute_set: 'AttributeSet') -> bool:
        """Check if the attribute set is a superset of the one in parameters.

        Args:
            other_attribute_set: The attribute set for which we check whether
                                 the attribute set is a superset of.

        Returns:
            The attribute set is a superset of the other attribute set.
        """
        self_attribute_ids_set = frozenset(self.attribute_ids)
        other_attribute_ids_set = frozenset(other_attribute_set.attribute_ids)
        return self_attribute_ids_set.issuperset(other_attribute_ids_set)

    def issubset(self, other_attribute_set: 'AttributeSet') -> bool:
        """Check if the attribute set is a subset of the one in parameters.

        Args:
            other_attribute_set: The attribute set for which we check whether
                                 the attribute set is a subset of.

        Returns:
            The attribute set is a subset of the other attribute set.
        """
        self_attribute_ids_set = frozenset(self.attribute_ids)
        other_attribute_ids_set = frozenset(other_attribute_set.attribute_ids)
        return self_attribute_ids_set.issubset(other_attribute_ids_set)

    def get_attribute_by_id(self, attribute_id: int) -> Attribute:
        """Give an attribute by its id.

        Args:
            attribute_id: The id of the attribute to retrieve.

        Raises:
            KeyError: The attribute is not present in this attribute set.
        """
        if attribute_id not in self._id_to_attr:
            raise KeyError(f'No attribute with the id {attribute_id}.')
        return self._id_to_attr[attribute_id]

    def get_attribute_by_name(self, name: str) -> Attribute:
        """Give an attribute by its name.

        Args:
            name: The name of the attribute to retrieve.

        Raises:
            KeyError: The attribute is not present in this attribute set.
        """
        for attribute in self._id_to_attr.values():
            if attribute.name == name:
                return attribute
        raise KeyError(f'No attribute is named {name}.')
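
A minimal usage sketch for AttributeSet. The real Attribute class and DuplicateAttributeId exception are defined elsewhere in the project, so a namedtuple stand-in exposing the two fields the set actually uses (attribute_id and name) serves here purely for illustration:

from collections import namedtuple

# hypothetical stand-in: AttributeSet only touches .attribute_id and .name
Attribute = namedtuple("Attribute", ["attribute_id", "name"])

ua = Attribute(1, "user_agent")
tz = Attribute(2, "timezone")

all_attrs = AttributeSet([ua, tz])
sub_attrs = AttributeSet([ua])

print(all_attrs.attribute_names)                     # ['user_agent', 'timezone']
print(sub_attrs.issubset(all_attrs))                 # True
print(ua in all_attrs)                               # True
print(all_attrs.get_attribute_by_name("timezone"))   # Attribute(attribute_id=2, name='timezone')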
Пример #45
0
class Replica(HasActionQueue, MessageProcessor):
    def __init__(self,
                 node: 'plenum.server.node.Node',
                 instId: int,
                 isMaster: bool = False):
        """
        Create a new replica.

        :param node: Node on which this replica is located
        :param instId: the id of the protocol instance the replica belongs to
        :param isMaster: is this a replica of the master protocol instance
        """
        HasActionQueue.__init__(self)
        self.stats = Stats(TPCStat)

        self.config = getConfig()

        routerArgs = [(ReqDigest, self._preProcessReqDigest)]

        for r in [PrePrepare, Prepare, Commit]:
            routerArgs.append((r, self.processThreePhaseMsg))

        routerArgs.append((Checkpoint, self.processCheckpoint))
        routerArgs.append((ThreePCState, self.process3PhaseState))

        self.inBoxRouter = Router(*routerArgs)

        self.threePhaseRouter = Router((PrePrepare, self.processPrePrepare),
                                       (Prepare, self.processPrepare),
                                       (Commit, self.processCommit))

        self.node = node
        self.instId = instId

        self.name = self.generateName(node.name, self.instId)

        self.outBox = deque()
        """
        This queue is used by the replica to send messages to its node. Replica
        puts messages that are consumed by its node
        """

        self.inBox = deque()
        """
        This queue is used by the replica to receive messages from its node.
        Node puts messages that are consumed by the replica
        """

        self.inBoxStash = deque()
        """
        If messages need to go back on the queue, they go here temporarily and
        are put back on the queue on a state change
        """

        self.isMaster = isMaster

        # Indicates name of the primary replica of this protocol instance.
        # None in case the replica does not know who the primary of the
        # instance is
        self._primaryName = None  # type: Optional[str]

        # Requests waiting to be processed once the replica is able to decide
        # whether it is primary or not
        self.postElectionMsgs = deque()

        # PRE-PREPAREs that are waiting to be processed but do not have the
        # corresponding request digest. Happens when replica has not been
        # forwarded the request by the node but is getting 3 phase messages.
        # The value is a list since a malicious entity might send PRE-PREPARE
        # with a different digest and since we don't have the request finalised,
        # we store all PRE-PREPAREs
        self.prePreparesPendingReqDigest = {
        }  # type: Dict[Tuple[str, int], List]

        # PREPAREs that are stored by non primary replica for which it has not
        #  got any PRE-PREPARE. Dictionary that stores a tuple of view no and
        #  prepare sequence number as key and a deque of PREPAREs as value.
        # This deque is attempted to be flushed on receiving every
        # PRE-PREPARE request.
        self.preparesWaitingForPrePrepare = {}
        # type: Dict[Tuple[int, int], deque]

        # COMMITs that are stored for which there are no PRE-PREPARE or PREPARE
        # received
        self.commitsWaitingForPrepare = {}
        # type: Dict[Tuple[int, int], deque]

        # Dictionary of sent PRE-PREPARE that are stored by primary replica
        # which it has broadcasted to all other non primary replicas
        # Key of dictionary is a 2 element tuple with elements viewNo,
        # pre-prepare seqNo and value is a tuple of Request Digest and time
        self.sentPrePrepares = {}
        # type: Dict[Tuple[int, int], Tuple[Tuple[str, int], float]]

        # Dictionary of received PRE-PREPAREs. Key of dictionary is a 2
        # element tuple with elements viewNo, pre-prepare seqNo and value is
        # a tuple of Request Digest and time
        self.prePrepares = {}
        # type: Dict[Tuple[int, int], Tuple[Tuple[str, int], float]]

        # Dictionary of received Prepare requests. Key of dictionary is a 2
        # element tuple with elements viewNo, seqNo and value is a 2 element
        # tuple containing request digest and set of sender node names(sender
        # replica names in case of multiple protocol instances)
        # (viewNo, seqNo) -> ((identifier, reqId), {senders})
        self.prepares = Prepares()
        # type: Dict[Tuple[int, int], Tuple[Tuple[str, int], Set[str]]]

        self.commits = Commits()  # type: Dict[Tuple[int, int],
        # Tuple[Tuple[str, int], Set[str]]]

        # Set of tuples to keep track of ordered requests. Each tuple is
        # (viewNo, ppSeqNo)
        self.ordered = OrderedSet()  # type: OrderedSet[Tuple[int, int]]

        # Dictionary to keep track of the which replica was primary during each
        # view. Key is the view no and value is the name of the primary
        # replica during that view
        self.primaryNames = {}  # type: Dict[int, str]

        # Holds msgs that are for later views
        self.threePhaseMsgsForLaterView = deque()
        # type: deque[(ThreePhaseMsg, str)]

        # Holds tuple of view no and prepare seq no of 3-phase messages it
        # received while it was not participating
        self.stashingWhileCatchingUp = set()  # type: Set[Tuple]

        # Commits which are not being ordered since commits with lower view
        # numbers and sequence numbers have not been ordered yet. Key is the
        # viewNo and value a map of pre-prepare sequence number to commit
        self.stashedCommitsForOrdering = {}  # type: Dict[int,
        # Dict[int, Commit]]

        self.checkpoints = SortedDict(lambda k: k[0])

        self.stashingWhileOutsideWaterMarks = deque()

        # Low water mark
        self._h = 0  # type: int

        # High water mark
        self.H = self._h + self.config.LOG_SIZE  # type: int

        self.lastPrePrepareSeqNo = self.h  # type: int

    @property
    def h(self) -> int:
        return self._h

    @h.setter
    def h(self, n):
        self._h = n
        self.H = self._h + self.config.LOG_SIZE

    @property
    def requests(self):
        return self.node.requests

    def shouldParticipate(self, viewNo: int, ppSeqNo: int):
        # The replica should only participate in the consensus process if the
        # node is participating and it did not stash any of this request's
        # 3-phase messages
        return self.node.isParticipating and (viewNo, ppSeqNo) \
                                             not in self.stashingWhileCatchingUp

    @staticmethod
    def generateName(nodeName: str, instId: int):
        """
        Create and return the name for a replica using its nodeName and
        instanceId.
         Ex: Alpha:1
        """
        return "{}:{}".format(nodeName, instId)

    @staticmethod
    def getNodeName(replicaName: str):
        return replicaName.split(":")[0]

    @property
    def isPrimary(self):
        """
        Is this node primary?

        :return: True if this node is primary, False otherwise
        """
        return self._primaryName == self.name if self._primaryName is not None \
            else None

    @property
    def primaryName(self):
        """
        Name of the primary replica of this replica's instance

        :return: Returns name if primary is known, None otherwise
        """
        return self._primaryName

    @primaryName.setter
    def primaryName(self, value: Optional[str]) -> None:
        """
        Set the value of isPrimary.

        :param value: the value to set isPrimary to
        """
        if not value == self._primaryName:
            self._primaryName = value
            self.primaryNames[self.viewNo] = value
            logger.debug("{} setting primaryName for view no {} to: {}".format(
                self, self.viewNo, value))
            logger.debug("{}'s primaryNames for views are: {}".format(
                self, self.primaryNames))
            self._stateChanged()

    def _stateChanged(self):
        """
        A series of actions to be performed when the state of this replica
        changes.

        - UnstashInBox (see _unstashInBox)
        """
        self._unstashInBox()
        if self.isPrimary is not None:
            # TODO handle suspicion exceptions here
            self.process3PhaseReqsQueue()
            # TODO handle suspicion exceptions here
            try:
                self.processPostElectionMsgs()
            except SuspiciousNode as ex:
                self.outBox.append(ex)
                self.discard(ex.msg, ex.reason, logger.warning)

    def _stashInBox(self, msg):
        """
        Stash the specified message into the inBoxStash of this replica.

        :param msg: the message to stash
        """
        self.inBoxStash.append(msg)

    def _unstashInBox(self):
        """
        Append the inBoxStash to the right of the inBox.
        """
        self.inBox.extend(self.inBoxStash)
        self.inBoxStash.clear()

    def __repr__(self):
        return self.name

    @property
    def f(self) -> int:
        """
        Return the number of Byzantine Failures that can be tolerated by this
        system. Equal to (N - 1)/3, where N is the number of nodes in the
        system.
        """
        return self.node.f

    @property
    def viewNo(self):
        """
        Return the current view number of this replica.
        """
        return self.node.viewNo

    def isPrimaryInView(self, viewNo: int) -> Optional[bool]:
        """
        Return whether a primary has been selected for this view number.
        """
        return self.primaryNames[viewNo] == self.name

    def isMsgForLaterView(self, msg):
        """
        Return whether this request's view number is greater than the current
        view number of this replica.
        """
        viewNo = getattr(msg, "viewNo", None)
        return viewNo > self.viewNo

    def isMsgForCurrentView(self, msg):
        """
        Return whether this request's view number is equal to the current view
        number of this replica.
        """
        viewNo = getattr(msg, "viewNo", None)
        return viewNo == self.viewNo

    def isMsgForPrevView(self, msg):
        """
        Return whether this request's view number is less than the current view
        number of this replica.
        """
        viewNo = getattr(msg, "viewNo", None)
        return viewNo < self.viewNo

    def isPrimaryForMsg(self, msg) -> Optional[bool]:
        """
        Return whether this replica is primary if the request's view number is
        equal this replica's view number and primary has been selected for
        the current view.
        Return None otherwise.

        :param msg: message
        """
        if self.isMsgForLaterView(msg):
            self.discard(
                msg, "Cannot get primary status for a request for a later "
                "view {}. Request is {}".format(self.viewNo, msg),
                logger.error)
        else:
            return self.isPrimary if self.isMsgForCurrentView(msg) \
                else self.isPrimaryInView(msg.viewNo)

    def isMsgFromPrimary(self, msg, sender: str) -> bool:
        """
        Return whether this message was from primary replica
        :param msg:
        :param sender:
        :return:
        """
        if self.isMsgForLaterView(msg):
            logger.error("{} cannot get primary for a request for a later "
                         "view. Request is {}".format(self, msg))
        else:
            return self.primaryName == sender if self.isMsgForCurrentView(
                msg) else self.primaryNames[msg.viewNo] == sender

    def _preProcessReqDigest(self, rd: ReqDigest) -> None:
        """
        Process request digest if this replica is not a primary, otherwise stash
        the message into the inBox.

        :param rd: the client Request Digest
        """
        if self.isPrimary is not None:
            self.processReqDigest(rd)
        else:
            logger.debug(
                "{} stashing request digest {} since it does not know "
                "its primary status".format(self, (rd.identifier, rd.reqId)))
            self._stashInBox(rd)

    def serviceQueues(self, limit=None):
        """
        Process `limit` number of messages in the inBox.

        :param limit: the maximum number of messages to process
        :return: the number of messages successfully processed
        """
        # TODO should handle SuspiciousNode here
        r = self.inBoxRouter.handleAllSync(self.inBox, limit)
        r += self._serviceActions()
        return r
        # Messages that cannot be processed right now need to be added back to
        # the queue. They might be able to be processed later

    def processPostElectionMsgs(self):
        """
        Process messages waiting for the election of a primary replica to
        complete.
        """
        while self.postElectionMsgs:
            msg = self.postElectionMsgs.popleft()
            logger.debug("{} processing pended msg {}".format(self, msg))
            self.dispatchThreePhaseMsg(*msg)

    def process3PhaseReqsQueue(self):
        """
        Process the 3 phase requests from the queue whose view number is equal
        to the current view number of this replica.
        """
        unprocessed = deque()
        while self.threePhaseMsgsForLaterView:
            request, sender = self.threePhaseMsgsForLaterView.popleft()
            logger.debug("{} processing pended 3 phase request: {}".format(
                self, request))
            # If the request is for a later view dont try to process it but add
            # it back to the queue.
            if self.isMsgForLaterView(request):
                unprocessed.append((request, sender))
            else:
                self.processThreePhaseMsg(request, sender)
        self.threePhaseMsgsForLaterView = unprocessed

    @property
    def quorum(self) -> int:
        r"""
        Return the quorum of this RBFT system. Equal to :math:`2f + 1`.
        Return None if `f` is not yet determined.
        """
        return self.node.quorum

    def dispatchThreePhaseMsg(self, msg: ThreePhaseMsg, sender: str) -> Any:
        """
        Create a three phase request to be handled by the threePhaseRouter.

        :param msg: the ThreePhaseMsg to dispatch
        :param sender: the name of the node that sent this request
        """
        senderRep = self.generateName(sender, self.instId)
        if self.isPpSeqNoAcceptable(msg.ppSeqNo):
            try:
                self.threePhaseRouter.handleSync((msg, senderRep))
            except SuspiciousNode as ex:
                self.node.reportSuspiciousNodeEx(ex)
        else:
            logger.debug("{} stashing 3 phase message {} since ppSeqNo {} is "
                         "not between {} and {}".format(
                             self, msg, msg.ppSeqNo, self.h, self.H))
            self.stashingWhileOutsideWaterMarks.append((msg, sender))

    def processReqDigest(self, rd: ReqDigest):
        """
        Process a request digest. Works only if this replica has decided its
        primary status.

        :param rd: the client request digest to process
        """
        self.stats.inc(TPCStat.ReqDigestRcvd)
        if self.isPrimary is False:
            self.dequeuePrePrepare(rd.identifier, rd.reqId)
        else:
            self.doPrePrepare(rd)

    def processThreePhaseMsg(self, msg: ThreePhaseMsg, sender: str):
        """
        Process a 3-phase (pre-prepare, prepare and commit) request.
        Dispatch the request only if primary has already been decided, otherwise
        stash it.

        :param msg: the Three Phase message, one of PRE-PREPARE, PREPARE,
            COMMIT
        :param sender: name of the node that sent this message
        """
        # Can only proceed further if it knows whether it is primary or not
        if self.isMsgForLaterView(msg):
            self.threePhaseMsgsForLaterView.append((msg, sender))
            logger.debug(
                "{} pended received 3 phase request for a later view: "
                "{}".format(self, msg))
        else:
            if self.isPrimary is None:
                self.postElectionMsgs.append((msg, sender))
                logger.debug("Replica {} pended request {} from {}".format(
                    self, msg, sender))
            else:
                self.dispatchThreePhaseMsg(msg, sender)

    def processPrePrepare(self, pp: PrePrepare, sender: str):
        """
        Validate and process the PRE-PREPARE specified.
        If validation is successful, create a PREPARE and broadcast it.

        :param pp: a prePrepareRequest
        :param sender: name of the node that sent this message
        """
        key = (pp.viewNo, pp.ppSeqNo)
        logger.debug("{} Receiving PRE-PREPARE{} at {} from {}".format(
            self, key, time.perf_counter(), sender))
        if self.canProcessPrePrepare(pp, sender):
            if not self.node.isParticipating:
                self.stashingWhileCatchingUp.add(key)
            self.addToPrePrepares(pp)
            logger.info("{} processed incoming PRE-PREPARE{}".format(
                self, key))

    def tryPrepare(self, pp: PrePrepare):
        """
        Try to send the Prepare message if the PrePrepare message is ready to
        be passed into the Prepare phase.
        """
        if self.canSendPrepare(pp):
            self.doPrepare(pp)
        else:
            logger.debug("{} cannot send PREPARE".format(self))

    def processPrepare(self, prepare: Prepare, sender: str) -> None:
        """
        Validate and process the PREPARE specified.
        If validation is successful, create a COMMIT and broadcast it.

        :param prepare: a PREPARE msg
        :param sender: name of the node that sent the PREPARE
        """
        # TODO move this try/except up higher
        logger.debug("{} received PREPARE{} from {}".format(
            self, (prepare.viewNo, prepare.ppSeqNo), sender))
        try:
            if self.isValidPrepare(prepare, sender):
                self.addToPrepares(prepare, sender)
                self.stats.inc(TPCStat.PrepareRcvd)
                logger.debug("{} processed incoming PREPARE {}".format(
                    self, (prepare.viewNo, prepare.ppSeqNo)))
            else:
                # TODO let's have isValidPrepare throw an exception that gets
                # handled and possibly logged higher
                logger.warning(
                    "{} cannot process incoming PREPARE".format(self))
        except SuspiciousNode as ex:
            self.node.reportSuspiciousNodeEx(ex)

    def processCommit(self, commit: Commit, sender: str) -> None:
        """
        Validate and process the COMMIT specified.
        If validation is successful, return the message to the node.

        :param commit: an incoming COMMIT message
        :param sender: name of the node that sent the COMMIT
        """
        logger.debug("{} received COMMIT {} from {}".format(
            self, commit, sender))
        if self.isValidCommit(commit, sender):
            self.stats.inc(TPCStat.CommitRcvd)
            self.addToCommits(commit, sender)
            logger.debug("{} processed incoming COMMIT{}".format(
                self, (commit.viewNo, commit.ppSeqNo)))

    def tryCommit(self, prepare: Prepare):
        """
        Try to commit if the Prepare message is ready to be passed into the
        commit phase.
        """
        if self.canCommit(prepare):
            self.doCommit(prepare)
        else:
            logger.debug("{} not yet able to send COMMIT".format(self))

    def tryOrder(self, commit: Commit):
        """
        Try to order if the Commit message is ready to be ordered.
        """
        canOrder, reason = self.canOrder(commit)
        if canOrder:
            logger.debug("{} returning request to node".format(self))
            self.tryOrdering(commit)
        else:
            logger.trace("{} cannot return request to node: {}".format(
                self, reason))

    def doPrePrepare(self, reqDigest: ReqDigest) -> None:
        """
        Broadcast a PRE-PREPARE to all the replicas.

        :param reqDigest: a tuple with elements identifier, reqId, and digest
        """
        if not self.node.isParticipating:
            logger.error("Non participating node is attempting PRE-PREPARE. "
                         "This should not happen.")
            return

        if self.lastPrePrepareSeqNo == self.H:
            logger.debug("{} stashing PRE-PREPARE {} since outside greater "
                         "than high water mark {}".format(
                             self, (self.viewNo, self.lastPrePrepareSeqNo + 1),
                             self.H))
            self.stashingWhileOutsideWaterMarks.append(reqDigest)
            return
        self.lastPrePrepareSeqNo += 1
        tm = time.time() * 1000
        logger.debug("{} Sending PRE-PREPARE {} at {}".format(
            self, (self.viewNo, self.lastPrePrepareSeqNo),
            time.perf_counter()))
        prePrepareReq = PrePrepare(self.instId, self.viewNo,
                                   self.lastPrePrepareSeqNo, *reqDigest, tm)
        self.sentPrePrepares[self.viewNo,
                             self.lastPrePrepareSeqNo] = (reqDigest.key, tm)
        self.send(prePrepareReq, TPCStat.PrePrepareSent)

    def doPrepare(self, pp: PrePrepare):
        logger.debug("{} Sending PREPARE {} at {}".format(
            self, (pp.viewNo, pp.ppSeqNo), time.perf_counter()))
        prepare = Prepare(self.instId, pp.viewNo, pp.ppSeqNo, pp.digest,
                          pp.ppTime)
        self.send(prepare, TPCStat.PrepareSent)
        self.addToPrepares(prepare, self.name)

    def doCommit(self, p: Prepare):
        """
        Create a commit message from the given Prepare message and trigger the
        commit phase
        :param p: the prepare message
        """
        logger.debug("{} Sending COMMIT{} at {}".format(
            self, (p.viewNo, p.ppSeqNo), time.perf_counter()))
        commit = Commit(self.instId, p.viewNo, p.ppSeqNo, p.digest, p.ppTime)
        self.send(commit, TPCStat.CommitSent)
        self.addToCommits(commit, self.name)

    def canProcessPrePrepare(self, pp: PrePrepare, sender: str) -> bool:
        """
        Decide whether this replica is eligible to process a PRE-PREPARE,
        based on the following criteria:

        - this replica is a non-primary replica
        - the request isn't already in its list of received PRE-PREPAREs
        - the underlying client request is finalised and its digest matches the
          PRE-PREPARE's digest

        :param pp: a PRE-PREPARE msg to process
        :param sender: the name of the node that sent the PRE-PREPARE msg
        :return: True if processing is allowed, False otherwise
        """
        # TODO: Check whether it is rejecting PRE-PREPARE from previous view
        # PRE-PREPARE should not be sent from non primary
        if not self.isMsgFromPrimary(pp, sender):
            raise SuspiciousNode(sender, Suspicions.PPR_FRM_NON_PRIMARY, pp)

        # A PRE-PREPARE is being sent to primary
        if self.isPrimaryForMsg(pp) is True:
            raise SuspiciousNode(sender, Suspicions.PPR_TO_PRIMARY, pp)

        # A PRE-PREPARE is sent that has already been received
        if (pp.viewNo, pp.ppSeqNo) in self.prePrepares:
            raise SuspiciousNode(sender, Suspicions.DUPLICATE_PPR_SENT, pp)

        key = (pp.identifier, pp.reqId)
        if not self.requests.isFinalised(key):
            self.enqueuePrePrepare(pp, sender)
            return False

        # A PRE-PREPARE is sent that does not match request digest
        if self.requests.digest(key) != pp.digest:
            raise SuspiciousNode(sender, Suspicions.PPR_DIGEST_WRONG, pp)

        return True

    def addToPrePrepares(self, pp: PrePrepare) -> None:
        """
        Add the specified PRE-PREPARE to this replica's list of received
        PRE-PREPAREs.

        :param pp: the PRE-PREPARE to add to the list
        """
        key = (pp.viewNo, pp.ppSeqNo)
        self.prePrepares[key] = \
            ((pp.identifier, pp.reqId), pp.ppTime)
        self.dequeuePrepares(*key)
        self.dequeueCommits(*key)
        self.stats.inc(TPCStat.PrePrepareRcvd)
        self.tryPrepare(pp)

    def hasPrepared(self, request) -> bool:
        return self.prepares.hasPrepareFrom(request, self.name)

    def canSendPrepare(self, request) -> bool:
        """
        Return whether the request identified by (identifier, requestId) can
        proceed to the Prepare step.

        :param request: any object with identifier and requestId attributes
        """
        return self.shouldParticipate(request.viewNo, request.ppSeqNo) \
            and not self.hasPrepared(request) \
            and self.requests.isFinalised((request.identifier,
                                           request.reqId))

    def isValidPrepare(self, prepare: Prepare, sender: str) -> bool:
        """
        Return whether the PREPARE specified is valid.

        :param prepare: the PREPARE to validate
        :param sender: the name of the node that sent the PREPARE
        :return: True if PREPARE is valid, False otherwise
        """
        key = (prepare.viewNo, prepare.ppSeqNo)
        primaryStatus = self.isPrimaryForMsg(prepare)

        ppReqs = self.sentPrePrepares if primaryStatus else self.prePrepares

        # If a non primary replica and receiving a PREPARE request before a
        # PRE-PREPARE request, then proceed

        # PREPARE should not be sent from primary
        if self.isMsgFromPrimary(prepare, sender):
            raise SuspiciousNode(sender, Suspicions.PR_FRM_PRIMARY, prepare)

        # If non primary replica
        if primaryStatus is False:
            if self.prepares.hasPrepareFrom(prepare, sender):
                raise SuspiciousNode(sender, Suspicions.DUPLICATE_PR_SENT,
                                     prepare)
            # If PRE-PREPARE not received for the PREPARE, might be slow network
            if key not in ppReqs:
                self.enqueuePrepare(prepare, sender)
                return False
            elif prepare.digest != self.requests.digest(ppReqs[key][0]):
                raise SuspiciousNode(sender, Suspicions.PR_DIGEST_WRONG,
                                     prepare)
            elif prepare.ppTime != ppReqs[key][1]:
                raise SuspiciousNode(sender, Suspicions.PR_TIME_WRONG, prepare)
            else:
                return True
        # If primary replica
        else:
            if self.prepares.hasPrepareFrom(prepare, sender):
                raise SuspiciousNode(sender, Suspicions.DUPLICATE_PR_SENT,
                                     prepare)
            # If PRE-PREPARE was not sent for this PREPARE, certainly
            # malicious behavior
            elif key not in ppReqs:
                raise SuspiciousNode(sender, Suspicions.UNKNOWN_PR_SENT,
                                     prepare)
            elif prepare.digest != self.requests.digest(ppReqs[key][0]):
                raise SuspiciousNode(sender, Suspicions.PR_DIGEST_WRONG,
                                     prepare)
            elif prepare.ppTime != ppReqs[key][1]:
                raise SuspiciousNode(sender, Suspicions.PR_TIME_WRONG, prepare)
            else:
                return True

    def addToPrepares(self, prepare: Prepare, sender: str):
        self.prepares.addVote(prepare, sender)
        self.tryCommit(prepare)

    def hasCommitted(self, request) -> bool:
        return self.commits.hasCommitFrom(
            ThreePhaseKey(request.viewNo, request.ppSeqNo), self.name)

    def canCommit(self, prepare: Prepare) -> bool:
        """
        Return whether the specified PREPARE can proceed to the Commit
        step.

        Decision criteria:

        - If this replica has received exactly 2f PREPARE requests, send a
            COMMIT.
        - If fewer than 2f PREPARE requests, there is probably no consensus on
            the request yet; don't commit.
        - If more than 2f, a COMMIT has already been sent; don't commit again.

        :param prepare: the PREPARE
        """
        return self.shouldParticipate(prepare.viewNo, prepare.ppSeqNo) and \
            self.prepares.hasQuorum(prepare, self.f) and \
            not self.hasCommitted(prepare)
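
        # Worked illustration (assumption): with f = 1 this replica sends its
        # COMMIT once it holds a quorum of 2*f = 2 matching PREPAREs for the
        # key and has not already committed, so it commits at most once.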

    def isValidCommit(self, commit: Commit, sender: str) -> bool:
        """
        Return whether the COMMIT specified is valid.

        :param commit: the COMMIT to validate
        :return: True if the COMMIT is valid, False otherwise
        """
        primaryStatus = self.isPrimaryForMsg(commit)
        ppReqs = self.sentPrePrepares if primaryStatus else self.prePrepares
        key = (commit.viewNo, commit.ppSeqNo)
        if key not in ppReqs:
            self.enqueueCommit(commit, sender)
            return False

        if (key not in self.prepares
                and key not in self.preparesWaitingForPrePrepare):
            logger.debug(
                "{} rejecting COMMIT{} due to lack of prepares".format(
                    self, key))
            # raise SuspiciousNode(sender, Suspicions.UNKNOWN_CM_SENT, commit)
            return False
        elif self.commits.hasCommitFrom(commit, sender):
            raise SuspiciousNode(sender, Suspicions.DUPLICATE_CM_SENT, commit)
        elif commit.digest != self.getDigestFor3PhaseKey(ThreePhaseKey(*key)):
            raise SuspiciousNode(sender, Suspicions.CM_DIGEST_WRONG, commit)
        elif key in ppReqs and commit.ppTime != ppReqs[key][1]:
            raise SuspiciousNode(sender, Suspicions.CM_TIME_WRONG, commit)
        else:
            return True

    def addToCommits(self, commit: Commit, sender: str):
        """
        Add the specified COMMIT to this replica's list of received
        commit requests.

        :param commit: the COMMIT to add to the list
        :param sender: the name of the node that sent the COMMIT
        """
        self.commits.addVote(commit, sender)
        self.tryOrder(commit)

    def hasOrdered(self, viewNo, ppSeqNo) -> bool:
        return (viewNo, ppSeqNo) in self.ordered

    def canOrder(self, commit: Commit) -> Tuple[bool, Optional[str]]:
        """
        Return whether the specified commitRequest can be returned to the node.

        Decision criteria:

        - If this replica has received exactly 2f+1 COMMIT requests, return the
            request to the node.
        - If fewer than 2f+1 COMMIT requests, there is probably no consensus on
            the request yet; don't return it to the node.
        - If more than 2f+1, the request has already been returned to the node;
            don't return it again.

        :param commit: the COMMIT
        """
        if not self.commits.hasQuorum(commit, self.f):
            return False, "no quorum: {} commits where f is {}".\
                          format(commit, self.f)

        if self.hasOrdered(commit.viewNo, commit.ppSeqNo):
            return False, "already ordered"

        if not self.isNextInOrdering(commit):
            viewNo, ppSeqNo = commit.viewNo, commit.ppSeqNo
            if viewNo not in self.stashedCommitsForOrdering:
                self.stashedCommitsForOrdering[viewNo] = {}
            self.stashedCommitsForOrdering[viewNo][ppSeqNo] = commit
            self.startRepeating(self.orderStashedCommits, 2)
            return False, "stashing {} since out of order".\
                format(commit)

        return True, None

    def isNextInOrdering(self, commit: Commit):
        viewNo, ppSeqNo = commit.viewNo, commit.ppSeqNo
        if self.ordered and self.ordered[-1] == (viewNo, ppSeqNo - 1):
            return True
        for (v, p) in self.commits:
            if v < viewNo:
                # Have commits from previous view that are unordered.
                # TODO: Question: would commits be always ordered, what if
                # some are never ordered and its fine, go to PBFT.
                return False
            if v == viewNo and p < ppSeqNo and (v, p) not in self.ordered:
                # If unordered commits are found with lower ppSeqNo then this
                # cannot be ordered.
                return False

        # TODO: Revisit PBFT paper, how to make sure that last request of the
        # last view has been ordered? Need change in `VIEW CHANGE` mechanism.
        # Somehow view change needs to communicate what the last request was.
        # Also what if some COMMITs were completely missed in the same view
        return True

    def orderStashedCommits(self):
        # TODO: What if the first few commits were out of order and stashed?
        # `self.ordered` would be empty
        if self.ordered:
            lastOrdered = self.ordered[-1]
            vToRemove = set()
            for v in self.stashedCommitsForOrdering:
                if v < lastOrdered[0] and self.stashedCommitsForOrdering[v]:
                    raise RuntimeError(
                        "{} found commits from previous view {}"
                        " that were not ordered but last ordered"
                        " is {}".format(self, v, lastOrdered))
                pToRemove = set()
                for p, commit in self.stashedCommitsForOrdering[v].items():
                    if (v == lastOrdered[0] and lastOrdered == (v, p - 1)) or \
                            (v > lastOrdered[0] and
                                self.isLowestCommitInView(commit)):
                        logger.debug("{} ordering stashed commit {}".format(
                            self, commit))
                        if self.tryOrdering(commit):
                            lastOrdered = (v, p)
                            pToRemove.add(p)

                for p in pToRemove:
                    del self.stashedCommitsForOrdering[v][p]
                if not self.stashedCommitsForOrdering[v]:
                    vToRemove.add(v)

            for v in vToRemove:
                del self.stashedCommitsForOrdering[v]

            # if self.stashedCommitsForOrdering:
            #     self._schedule(self.orderStashedCommits, 2)
            if not self.stashedCommitsForOrdering:
                self.stopRepeating(self.orderStashedCommits)

    def isLowestCommitInView(self, commit):
        # TODO: Assumption: This assumes that at least one commit that was sent
        #  for any request by any node has been received in the view of this
        # commit
        ppSeqNos = []
        for v, p in self.commits:
            if v == commit.viewNo:
                ppSeqNos.append(p)
        return min(ppSeqNos) == commit.ppSeqNo if ppSeqNos else True

    def tryOrdering(self, commit: Commit) -> bool:
        """
        Attempt to send an ORDERED request for the specified COMMIT to the
        node.

        :param commit: the COMMIT message
        """
        key = (commit.viewNo, commit.ppSeqNo)
        logger.debug("{} trying to order COMMIT{}".format(self, key))
        reqKey = self.getReqKeyFrom3PhaseKey(key)  # type: Tuple
        digest = self.getDigestFor3PhaseKey(key)
        if not digest:
            logger.error(
                "{} did not find digest for {}, request key {}".format(
                    self, key, reqKey))
            return False
        self.doOrder(*key, *reqKey, digest, commit.ppTime)
        return True

    def doOrder(self, viewNo, ppSeqNo, identifier, reqId, digest, ppTime):
        key = (viewNo, ppSeqNo)
        self.addToOrdered(*key)
        ordered = Ordered(self.instId, viewNo, identifier, reqId, ppTime)
        # TODO: Should not order or add to checkpoint while syncing
        # 3 phase state.
        self.send(ordered, TPCStat.OrderSent)
        if key in self.stashingWhileCatchingUp:
            self.stashingWhileCatchingUp.remove(key)
        logger.debug("{} ordered request {}".format(self, (viewNo, ppSeqNo)))
        self.addToCheckpoint(ppSeqNo, digest)

    def processCheckpoint(self, msg: Checkpoint, sender: str):
        if self.checkpoints:
            seqNo = msg.seqNo
            _, firstChk = self.firstCheckPoint
            if firstChk.isStable:
                if firstChk.seqNo == seqNo:
                    self.discard(msg,
                                 reason="Checkpoint already stable",
                                 logMethod=logger.debug)
                    return
                if firstChk.seqNo > seqNo:
                    self.discard(msg,
                                 reason="Higher stable checkpoint present",
                                 logMethod=logger.debug)
                    return
            for state in self.checkpoints.values():
                if state.seqNo == seqNo:
                    if state.digest == msg.digest:
                        state.receivedDigests[sender] = msg.digest
                        break
                    else:
                        logger.error("{} received an incorrect digest {} for "
                                     "checkpoint {} from {}".format(
                                         self, msg.digest, seqNo, sender))
                        return
            if len(state.receivedDigests) == 2 * self.f:
                self.markCheckPointStable(msg.seqNo)
        else:
            self.discard(msg,
                         reason="No checkpoints present to tally",
                         logMethod=logger.warn)

    def _newCheckpointState(self, ppSeqNo, digest) -> CheckpointState:
        s, e = ppSeqNo, ppSeqNo + self.config.CHK_FREQ - 1
        logger.debug("{} adding new checkpoint state for {}".format(
            self, (s, e)))
        state = CheckpointState(ppSeqNo, [
            digest,
        ], None, {}, False)
        self.checkpoints[s, e] = state
        return state

    def addToCheckpoint(self, ppSeqNo, digest):
        for (s, e) in self.checkpoints.keys():
            if s <= ppSeqNo <= e:
                state = self.checkpoints[s, e]  # type: CheckpointState
                state.digests.append(digest)
                state = updateNamedTuple(state, seqNo=ppSeqNo)
                self.checkpoints[s, e] = state
                break
        else:
            state = self._newCheckpointState(ppSeqNo, digest)
            s, e = ppSeqNo, ppSeqNo + self.config.CHK_FREQ - 1

        if len(state.digests) == self.config.CHK_FREQ:
            state = updateNamedTuple(state,
                                     digest=serialize(state.digests),
                                     digests=[])
            self.checkpoints[s, e] = state
            self.send(
                Checkpoint(self.instId, self.viewNo, ppSeqNo, state.digest))
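
        # Illustrative walk-through (assumption): with CHK_FREQ = 100 the first
        # ordered ppSeqNo opens the window (ppSeqNo, ppSeqNo + 99); once 100
        # digests have accumulated, their serialized digest is sent in a
        # Checkpoint message and the per-request digest list is cleared.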

    def markCheckPointStable(self, seqNo):
        previousCheckpoints = []
        for (s, e), state in self.checkpoints.items():
            if e == seqNo:
                state = updateNamedTuple(state, isStable=True)
                self.checkpoints[s, e] = state
                break
            else:
                previousCheckpoints.append((s, e))
        else:
            logger.error("{} could not find {} in checkpoints".format(
                self, seqNo))
            return
        self.h = seqNo
        for k in previousCheckpoints:
            logger.debug("{} removing previous checkpoint {}".format(self, k))
            self.checkpoints.pop(k)
        self.gc(seqNo)
        logger.debug("{} marked stable checkpoint {}".format(self, (s, e)))
        self.processStashedMsgsForNewWaterMarks()

    def gc(self, tillSeqNo):
        logger.debug("{} cleaning up till {}".format(self, tillSeqNo))
        tpcKeys = set()
        reqKeys = set()
        for (v, p), (reqKey, _) in self.sentPrePrepares.items():
            if p <= tillSeqNo:
                tpcKeys.add((v, p))
                reqKeys.add(reqKey)
        for (v, p), (reqKey, _) in self.prePrepares.items():
            if p <= tillSeqNo:
                tpcKeys.add((v, p))
                reqKeys.add(reqKey)

        logger.debug("{} found {} 3 phase keys to clean".format(
            self, len(tpcKeys)))
        logger.debug("{} found {} request keys to clean".format(
            self, len(reqKeys)))

        for k in tpcKeys:
            self.sentPrePrepares.pop(k, None)
            self.prePrepares.pop(k, None)
            self.prepares.pop(k, None)
            self.commits.pop(k, None)
            if k in self.ordered:
                self.ordered.remove(k)

        for k in reqKeys:
            self.requests.pop(k, None)

    def processStashedMsgsForNewWaterMarks(self):
        while self.stashingWhileOutsideWaterMarks:
            item = self.stashingWhileOutsideWaterMarks.pop()
            logger.debug("{} processing stashed item {} after new stable "
                         "checkpoint".format(self, item))

            if isinstance(item, ReqDigest):
                self.doPrePrepare(item)
            elif isinstance(item, tuple) and len(item) == 2:
                self.dispatchThreePhaseMsg(*item)
            else:
                logger.error("{} cannot process {} "
                             "from stashingWhileOutsideWaterMarks".format(
                                 self, item))

    @property
    def firstCheckPoint(self) -> Tuple[Tuple[int, int], CheckpointState]:
        if not self.checkpoints:
            return None
        else:
            return self.checkpoints.peekitem(0)

    @property
    def lastCheckPoint(self) -> Tuple[Tuple[int, int], CheckpointState]:
        if not self.checkpoints:
            return None
        else:
            return self.checkpoints.peekitem(-1)

    def isPpSeqNoAcceptable(self, ppSeqNo: int):
        return self.h < ppSeqNo <= self.H

    def addToOrdered(self, viewNo: int, ppSeqNo: int):
        self.ordered.add((viewNo, ppSeqNo))

    def enqueuePrePrepare(self, request: PrePrepare, sender: str):
        logger.debug(
            "Queueing pre-prepares due to unavailability of finalised "
            "Request. Request {} from {}".format(request, sender))
        key = (request.identifier, request.reqId)
        if key not in self.prePreparesPendingReqDigest:
            self.prePreparesPendingReqDigest[key] = []
        self.prePreparesPendingReqDigest[key].append((request, sender))

    def dequeuePrePrepare(self, identifier: int, reqId: int):
        key = (identifier, reqId)
        if key in self.prePreparesPendingReqDigest:
            pps = self.prePreparesPendingReqDigest[key]
            for (pp, sender) in pps:
                logger.debug("{} popping stashed PRE-PREPARE{}".format(
                    self, key))
                if pp.digest == self.requests.digest(key):
                    self.prePreparesPendingReqDigest.pop(key)
                    self.processPrePrepare(pp, sender)
                    logger.debug(
                        "{} processed {} PRE-PREPAREs waiting for finalised "
                        "request for identifier {} and reqId {}".format(
                            self, pp, identifier, reqId))
                    break

    def enqueuePrepare(self, request: Prepare, sender: str):
        logger.debug("Queueing prepares due to unavailability of PRE-PREPARE. "
                     "Request {} from {}".format(request, sender))
        key = (request.viewNo, request.ppSeqNo)
        if key not in self.preparesWaitingForPrePrepare:
            self.preparesWaitingForPrePrepare[key] = deque()
        self.preparesWaitingForPrePrepare[key].append((request, sender))

    def dequeuePrepares(self, viewNo: int, ppSeqNo: int):
        key = (viewNo, ppSeqNo)
        if key in self.preparesWaitingForPrePrepare:
            i = 0
            # Count of pending PREPAREs processed in the loop below
            while self.preparesWaitingForPrePrepare[key]:
                prepare, sender = self.preparesWaitingForPrePrepare[
                    key].popleft()
                logger.debug("{} popping stashed PREPARE{}".format(self, key))
                self.processPrepare(prepare, sender)
                i += 1
            self.preparesWaitingForPrePrepare.pop(key)
            logger.debug("{} processed {} PREPAREs waiting for PRE-PREPARE for"
                         " view no {} and seq no {}".format(
                             self, i, viewNo, ppSeqNo))

    def enqueueCommit(self, request: Commit, sender: str):
        logger.debug("Queueing commit due to unavailability of PREPARE. "
                     "Request {} from {}".format(request, sender))
        key = (request.viewNo, request.ppSeqNo)
        if key not in self.commitsWaitingForPrepare:
            self.commitsWaitingForPrepare[key] = deque()
        self.commitsWaitingForPrepare[key].append((request, sender))

    def dequeueCommits(self, viewNo: int, ppSeqNo: int):
        key = (viewNo, ppSeqNo)
        if key in self.commitsWaitingForPrepare:
            i = 0
            # Count of pending COMMITs processed in the loop below
            while self.commitsWaitingForPrepare[key]:
                commit, sender = self.commitsWaitingForPrepare[key].popleft()
                logger.debug("{} popping stashed COMMIT{}".format(self, key))
                self.processCommit(commit, sender)
                i += 1
            self.commitsWaitingForPrepare.pop(key)
            logger.debug("{} processed {} COMMITs waiting for PREPARE for"
                         " view no {} and seq no {}".format(
                             self, i, viewNo, ppSeqNo))

    def getDigestFor3PhaseKey(self, key: ThreePhaseKey) -> Optional[str]:
        reqKey = self.getReqKeyFrom3PhaseKey(key)
        digest = self.requests.digest(reqKey)
        if not digest:
            logger.debug("{} could not find digest in sent or received "
                         "PRE-PREPAREs or PREPAREs for 3 phase key {} and req "
                         "key {}".format(self, key, reqKey))
            return None
        else:
            return digest

    def getReqKeyFrom3PhaseKey(self, key: ThreePhaseKey):
        reqKey = None
        if key in self.sentPrePrepares:
            reqKey = self.sentPrePrepares[key][0]
        elif key in self.prePrepares:
            reqKey = self.prePrepares[key][0]
        elif key in self.prepares:
            reqKey = self.prepares[key][0]
        else:
            logger.debug(
                "Could not find request key for 3 phase key {}".format(key))
        return reqKey

    @property
    def threePhaseState(self):
        # TODO: This method is incomplete
        # Gets the current stable and unstable checkpoints and creates digest
        # of unstable checkpoints
        state = []
        if self.checkpoints:
            pass
        return ThreePCState(self.instId, state)

    def process3PhaseState(self, msg: ThreePCState, sender: str):
        # TODO: This is not complete
        pass

    def send(self, msg, stat=None) -> None:
        """
        Send a message to the node on which this replica resides.

        :param msg: the message to send
        """
        logger.display("{} sending {}".format(self, msg.__class__.__name__),
                       extra={"cli": True})
        logger.trace("{} sending {}".format(self, msg))
        if stat:
            self.stats.inc(stat)
        self.outBox.append(msg)
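
The checkpoints container used above relies on SortedDict's order-aware API: peekitem(0) and peekitem(-1) return the oldest and newest checkpoint windows without scanning, and deleting the keys below a newly stable window is essentially what gc() does. A minimal, self-contained sketch of that pattern (the window keys and state strings here are made up for illustration):

from sortedcontainers import SortedDict

checkpoints = SortedDict()
checkpoints[(1, 100)] = "state-A"      # hypothetical checkpoint window -> state
checkpoints[(101, 200)] = "state-B"
checkpoints[(201, 300)] = "state-C"

first_key, first_state = checkpoints.peekitem(0)    # ((1, 100), 'state-A')
last_key, last_state = checkpoints.peekitem(-1)     # ((201, 300), 'state-C')

# Drop everything at or below the newly stable window, roughly what
# markCheckPointStable/gc do with the real checkpoint state.
for key in [k for k in checkpoints if k[1] <= 200]:
    del checkpoints[key]
assert list(checkpoints) == [(201, 300)]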
Example #46
0
class Superplot():
    """
Self-contained plotting class that runs in its own process.
Plotting functionality (reset the graph, .. ?) can be controlled
by issuing message-based commands using a multiprocessing Pipe

    """
    def __init__(self,name,plottype=PlotType.indexed):
        self.name = name
        self.plottype = plottype
        self._clear()

    def _clear(self):
        # Process-local buffers used to host the displayed data
        if self.plottype == PlotType.linear:
            self.set = True
            self.x = []
            self.y = []
        else:
            self.xy = SortedDict()
            # TODO : use this optimization, but for now raises issue
            # Can't pickle dict_key views ??
            #self.x = self.xy.keys()
            #self.y = self.xy.values()
            self.set = False

    def start(self):
        # The queue that will be used to transfer data from the main process
        # to the plot
        self.q = Queue()
        main_pipe, self.in_process_pipe = Pipe()
        self.p = Process(target=self.run)
        self.p.start()
        # Return a handle to the data queue and the control pipe
        return self.q, main_pipe

    def join(self):
        self.p.join()

    def _update(self):
        # Empty data queue and process received data
        while not self.q.empty():
            item = self.q.get()
            if self.plottype == PlotType.linear:
                self.x.append(item[0])
                self.y.append(item[1])
            else:
                # Seems pretty slow,
                # TODO : Profile
                # TODO : Eventually, need to find high performance alternative. Maybe numpy based
                self.xy[item[0]] = item[1]

        # Initialize view on data dictionary only once for increased performance
        if not self.set:
            self.set = True
            self.x = self.xy.keys()
            self.y = self.xy.values()

        # Refresh plot data
        self.curve.setData(self.x,self.y)

        try:
            if self.in_process_pipe.poll():
                msg = self.in_process_pipe.recv()
                self._process_msg(msg)
        except:
            # If the polling failed, then the application most likely shut down
            # So close the window and terminate as well
            self.app.quit()

    def _process_msg(self, msg):
        if msg == "exit":
            # TODO : Remove this line ? Redundant with send after app.exec_() ?
            self.in_process_pipe.send("closing")
            self.app.quit()
        elif msg == "clear":
            self._clear()

    def run(self):
        self.app = QtGui.QApplication([])
        win = pg.GraphicsWindow(title="Basic plotting examples")
        win.resize(1000,600)
        win.setWindowTitle('pyqtgraph example: Plotting')
        plot = win.addPlot(title=self.name)
        self.curve = plot.plot(pen='y')

        timer = QtCore.QTimer()
        timer.timeout.connect(self._update)
        timer.start(50)

        self.app.exec_()
        try:
            self.in_process_pipe.send("closing")
        except:
            pass
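
A hedged usage sketch for the Superplot class above (it assumes Superplot and PlotType are importable from the module shown and that a Qt/pyqtgraph environment is available). Data points are pushed through the returned queue and control messages travel over the pipe:

# Hypothetical driver process for the Superplot defined above.
plot = Superplot("sensor", PlotType.indexed)
q, pipe = plot.start()            # spawns the plotting process

for i in range(100):
    q.put((i, i * i))             # (key, value) pairs land in the SortedDict buffer

pipe.send("clear")                # reset the process-local buffers
pipe.send("exit")                 # ask the plotting process to quit
plot.join()                       # wait for it to terminate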
Example #47
0
def doMarginal(net, argies):
    prob = 0
    for i in argies:
        # First element of the argument: variable
        var = i[0]
        # Second element of the argument: state
        state = i[1]
        # Base cases for the marginal recursion
        if var == 'p':
            if state == 'true':
                prob += net[0].p['pL']
            else:
                prob += net[0].p['pH']
        if var == 's':
            if state == 'true':
                prob += net[1].p['sT']
            else:
                prob += net[1].p['sF']
        if var == 'c':
            pars = net[2].parents
            a = SortedDict(net[2].p)
            b = SortedDict(pars[0].p)
            c = SortedDict(pars[1].p)
            # MARGINAL LOGIC
            if state == 'true':
                for combos in a.keys():
                    for combo1 in b.keys():
                        for combo2 in c.keys():
                            if combo1 in combos and combo2 in combos:
                                prob += net[2].p[combos] * pars[0].p[combo1] * pars[1].p[combo2]
            elif state == 'false':
                prob = 1 - doMarginal(net, [('c', 'true')])

            # CONDITIONALS FROM CLASS
            elif state == 'sT':
                for combos in a.keys():
                    for combo1 in b.keys():
                        for combo2 in c.keys():
                            if combo2 == 'sT' and combo1 in combos and combo2 in combos:
                                breakP()  # debugging hook defined elsewhere
                                prob += net[2].p[combos] * pars[0].p[combo1] * pars[1].p[combo2]
            elif state == 'sF':
                for combos in a.keys():
                    for combo1 in b.keys():
                        for combo2 in c.keys():
                            if combo2 == 'sF' and combo1 in combos and combo2 in combos:
                                breakP()
                                prob += net[2].p[combos] * pars[0].p[combo1] * pars[1].p[combo2]
            elif state == 'pH':
                for combos in a.keys():
                    for combo1 in b.keys():
                        for combo2 in c.keys():
                            if combo1 == 'pH' and combo1 in combos and combo2 in combos:
                                breakP()
                                prob += net[2].p[combos] * pars[0].p[combo1] * pars[1].p[combo2]
            elif state == 'pL':
                for combos in a.keys():
                    for combo1 in b.keys():
                        for combo2 in c.keys():
                            if combo1 == 'pL' and combo1 in combos and combo2 in combos:
                                breakP()
                                prob += net[2].p[combos] * pars[0].p[combo1] * pars[1].p[combo2]
            else:
                raise Exception('CANCER CANT BE DETERMINED FROM XRAY OR DYSPNOEA BECAUSE THEY DEPEND ON CANCER OUTPUT')
        if var == 'x':
            pars = net[3].parents
            a = SortedDict(net[3].p)
            b = SortedDict(pars[0].p)
            if state == 'true':
                for combos in a.keys():
                    for combo1 in b.keys():
                        if combos == 'XcT':
                            prob += net[3].p[combos] * pars[0].p[combo1]

            # if state == 'true':
            #     prob += net[2].p['CpHsT']*net[0].['pH']*net[1].['sT']
            #     prob += net[2].p['CpLsT']*net[0].['pL']*net[1].['sT']
            #     prob += net[2].p['Cp

    return prob
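
The 'c' branch above computes P(C = true) as the sum over parent states of P(C = true | P, S) * P(P) * P(S). A standalone sketch of that marginalisation with made-up probabilities (the 'CpHsT'-style keys mirror the lookups above, but every number here is hypothetical):

from sortedcontainers import SortedDict

# Hypothetical conditional probability tables, keyed the way the loops above expect.
cancer = SortedDict({'CpHsT': 0.05, 'CpHsF': 0.01, 'CpLsT': 0.02, 'CpLsF': 0.001})
pollution = SortedDict({'pH': 0.1, 'pL': 0.9})
smoker = SortedDict({'sT': 0.3, 'sF': 0.7})

p_c_true = 0.0
for combo in cancer:                  # e.g. 'CpHsT'
    for p_state in pollution:         # 'pH' or 'pL'
        for s_state in smoker:        # 'sT' or 'sF'
            if p_state in combo and s_state in combo:
                p_c_true += cancer[combo] * pollution[p_state] * smoker[s_state]

print(p_c_true)  # P(C = true), marginalised over pollution and smoker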
Example #48
0
class OrderedDict(dict):
    """Dictionary that remembers insertion order and is numerically indexable.

    Keys are numerically indexable using dict views. For example::

        >>> ordered_dict = OrderedDict.fromkeys('abcde')
        >>> keys = ordered_dict.keys()
        >>> keys[0]
        'a'
        >>> keys[-2:]
        ['d', 'e']

    The dict views support the sequence abstract base class.

    """

    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        self._keys = {}
        self._nums = SortedDict()
        self._keys_view = self._nums.keys()
        self._count = count()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
        "``ordered_dict[key] = value``"
        if key not in self:
            num = next(self._count)
            self._keys[key] = num
            self._nums[num] = key
        dict_setitem(self, key, value)

    def __delitem__(self, key, dict_delitem=dict.__delitem__):
        "``del ordered_dict[key]``"
        dict_delitem(self, key)
        num = self._keys.pop(key)
        del self._nums[num]

    def __iter__(self):
        "``iter(ordered_dict)``"
        return iter(self._nums.values())

    def __reversed__(self):
        "``reversed(ordered_dict)``"
        nums = self._nums
        for key in reversed(nums):
            yield nums[key]

    def clear(self, dict_clear=dict.clear):
        "Remove all items from mapping."
        dict_clear(self)
        self._keys.clear()
        self._nums.clear()

    def popitem(self, last=True):
        """Remove and return (key, value) item pair.

        Pairs are returned in LIFO order if last is True or FIFO order if
        False.

        """
        index = -1 if last else 0
        num = self._keys_view[index]
        key = self._nums[num]
        value = self.pop(key)
        return key, value

    update = __update = co.MutableMapping.update

    def keys(self):
        "Return set-like and sequence-like view of mapping keys."
        return KeysView(self)

    def items(self):
        "Return set-like and sequence-like view of mapping items."
        return ItemsView(self)

    def values(self):
        "Return set-like and sequence-like view of mapping values."
        return ValuesView(self)

    def pop(self, key, default=NONE):
        """Remove given key and return corresponding value.

        If key is not found, default is returned if given, otherwise raise
        KeyError.

        """
        if key in self:
            value = self[key]
            del self[key]
            return value
        elif default is NONE:
            raise KeyError(key)
        else:
            return default

    def setdefault(self, key, default=None):
        """Return ``mapping.get(key, default)``, also set ``mapping[key] = default`` if
        key not in mapping.

        """
        if key in self:
            return self[key]
        self[key] = default
        return default

    @recursive_repr()
    def __repr__(self):
        "Text representation of mapping."
        return '%s(%r)' % (self.__class__.__name__, list(self.items()))

    __str__ = __repr__

    def __reduce__(self):
        "Support for pickling serialization."
        return (self.__class__, (list(self.items()), ))

    def copy(self):
        "Return shallow copy of mapping."
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Return new mapping with keys from iterable.

        If not specified, value defaults to None.

        """
        return cls((key, value) for key in iterable)

    def __eq__(self, other):
        "Test self and other mapping for equality."
        if isinstance(other, OrderedDict):
            return dict.__eq__(self, other) and all(map(eq, self, other))
        return dict.__eq__(self, other)

    __ne__ = co.MutableMapping.__ne__

    def _check(self):
        "Check consistency of internal member variables."
        # pylint: disable=protected-access
        keys = self._keys
        nums = self._nums

        for key, value in keys.items():
            assert nums[value] == key

        nums._check()
Example #49
0
class IntervalTree(collections.MutableSet):
    """
    A binary lookup tree of intervals.
    The intervals contained in the tree are represented using ``Interval(a, b, data)`` objects.
    Each such object represents a half-open interval ``[a, b)`` with optional data.
    
    Examples:
    ---------
    
    Initialize a blank tree::
    
        >>> tree = IntervalTree()
        >>> tree
        IntervalTree()
    
    Initialize a tree from an iterable set of Intervals in O(n * log n)::
    
        >>> tree = IntervalTree([Interval(-10, 10), Interval(-20.0, -10.0)])
        >>> tree
        IntervalTree([Interval(-20.0, -10.0), Interval(-10, 10)])
        >>> len(tree)
        2
    
    Note that this is a set, i.e. repeated intervals are ignored. However,
    Intervals with different data fields are regarded as different::
    
        >>> tree = IntervalTree([Interval(-10, 10), Interval(-10, 10), Interval(-10, 10, "x")])
        >>> tree
        IntervalTree([Interval(-10, 10), Interval(-10, 10, 'x')])
        >>> len(tree)
        2
    
    Insertions::
        >>> tree = IntervalTree()
        >>> tree[0:1] = "data"
        >>> tree.add(Interval(10, 20))
        >>> tree.addi(19.9, 20)
        >>> tree
        IntervalTree([Interval(0, 1, 'data'), Interval(10, 20), Interval(19.9, 20)])
        >>> tree.update([Interval(19.9, 20.1), Interval(20.1, 30)])
        >>> len(tree)
        5

        Inserting the same Interval twice does nothing::
            >>> tree = IntervalTree()
            >>> tree[-10:20] = "arbitrary data"
            >>> tree[-10:20] = None  # Note that this is also an insertion
            >>> tree
            IntervalTree([Interval(-10, 20), Interval(-10, 20, 'arbitrary data')])
            >>> tree[-10:20] = None  # This won't change anything
            >>> tree[-10:20] = "arbitrary data" # Neither will this
            >>> len(tree)
            2

    Deletions::
        >>> tree = IntervalTree(Interval(b, e) for b, e in [(-10, 10), (-20, -10), (10, 20)])
        >>> tree
        IntervalTree([Interval(-20, -10), Interval(-10, 10), Interval(10, 20)])
        >>> tree.remove(Interval(-10, 10))
        >>> tree
        IntervalTree([Interval(-20, -10), Interval(10, 20)])
        >>> tree.remove(Interval(-10, 10))
        Traceback (most recent call last):
        ...
        ValueError
        >>> tree.discard(Interval(-10, 10))  # Same as remove, but no exception on failure
        >>> tree
        IntervalTree([Interval(-20, -10), Interval(10, 20)])
        
    Delete intervals, overlapping a given point::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.1)
        >>> tree
        IntervalTree([Interval(-1.1, 1.1)])
        
    Delete intervals, overlapping an interval::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> tree.remove_overlap(0, 0.5)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.7, 1.8)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.6, 1.6)  # Null interval does nothing
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_overlap(1.6, 1.5)  # Ditto
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        
    Delete intervals, enveloped in the range::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> tree.remove_envelop(-1.0, 1.5)
        >>> tree
        IntervalTree([Interval(-1.1, 1.1), Interval(0.5, 1.7)])
        >>> tree.remove_envelop(-1.1, 1.5)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_envelop(0.5, 1.5)
        >>> tree
        IntervalTree([Interval(0.5, 1.7)])
        >>> tree.remove_envelop(0.5, 1.7)
        >>> tree
        IntervalTree()
        
    Point/interval overlap queries::
    
        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> assert tree[-1.1]         == set([Interval(-1.1, 1.1)])
        >>> assert tree.search(1.1)   == set([Interval(-0.5, 1.5), Interval(0.5, 1.7)])   # Same as tree[1.1]
        >>> assert tree[-0.5:0.5]     == set([Interval(-0.5, 1.5), Interval(-1.1, 1.1)])  # Interval overlap query
        >>> assert tree.search(1.5, 1.5) == set()                                         # Same as tree[1.5:1.5]
        >>> assert tree.search(1.5) == set([Interval(0.5, 1.7)])                          # Same as tree[1.5]

        >>> assert tree.search(1.7, 1.8) == set()

    Envelop queries::
    
        >>> assert tree.search(-0.5, 0.5, strict=True) == set()
        >>> assert tree.search(-0.4, 1.7, strict=True) == set([Interval(0.5, 1.7)])
        
    Membership queries::

        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> Interval(-0.5, 0.5) in tree
        False
        >>> Interval(-1.1, 1.1) in tree
        True
        >>> Interval(-1.1, 1.1, "x") in tree
        False
        >>> tree.overlaps(-1.1)
        True
        >>> tree.overlaps(1.7)
        False
        >>> tree.overlaps(1.7, 1.8)
        False
        >>> tree.overlaps(-1.2, -1.1)
        False
        >>> tree.overlaps(-1.2, -1.0)
        True
    
    Sizing::

        >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)])
        >>> len(tree)
        3
        >>> tree.is_empty()
        False
        >>> IntervalTree().is_empty()
        True
        >>> not tree
        False
        >>> not IntervalTree()
        True
        >>> print(tree.begin())    # using print() because of floats in Python 2.6
        -1.1
        >>> print(tree.end())      # ditto
        1.7
        
    Iteration::

        >>> tree = IntervalTree([Interval(-11, 11), Interval(-5, 15), Interval(5, 17)])
        >>> [iv.begin for iv in sorted(tree)]
        [-11, -5, 5]
        >>> assert tree.items() == set([Interval(-5, 15), Interval(-11, 11), Interval(5, 17)])

    Copy- and typecasting, pickling::
    
        >>> tree0 = IntervalTree([Interval(0, 1, "x"), Interval(1, 2, ["x"])])
        >>> tree1 = IntervalTree(tree0)  # Shares Interval objects
        >>> tree2 = tree0.copy()         # Shallow copy (same as above, as Intervals are singletons)
        >>> import pickle
        >>> tree3 = pickle.loads(pickle.dumps(tree0))  # Deep copy
        >>> list(tree0[1])[0].data[0] = "y"  # affects shallow copies, but not deep copies
        >>> tree0
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])])
        >>> tree1
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])])
        >>> tree2
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])])
        >>> tree3
        IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['x'])])
        
    Equality testing::
    
        >>> IntervalTree([Interval(0, 1)]) == IntervalTree([Interval(0, 1)])
        True
        >>> IntervalTree([Interval(0, 1)]) == IntervalTree([Interval(0, 1, "x")])
        False
    """
    @classmethod
    def from_tuples(cls, tups):
        """
        Create a new IntervalTree from an iterable of 2- or 3-tuples,
         where the tuple lists begin, end, and optionally data.
        """
        ivs = [Interval(*t) for t in tups]
        return IntervalTree(ivs)

    def __init__(self, intervals=None):
        """
        Set up a tree. If intervals is provided, add all the intervals 
        to the tree.
        
        Completes in O(n*log n) time.
        """
        intervals = set(intervals) if intervals is not None else set()
        for iv in intervals:
            if iv.is_null():
                raise ValueError(
                    "IntervalTree: Null Interval objects not allowed in IntervalTree:"
                    " {0}".format(iv)
                )
        self.all_intervals = intervals
        self.top_node = Node.from_intervals(self.all_intervals)
        self.boundary_table = SortedDict()
        for iv in self.all_intervals:
            self._add_boundaries(iv)

    def copy(self):
        """
        Construct a new IntervalTree using shallow copies of the 
        intervals in the source tree.
        
        Completes in O(n*log n) time.
        :rtype: IntervalTree
        """
        return IntervalTree(iv.copy() for iv in self)
    
    def _add_boundaries(self, interval):
        """
        Records the boundaries of the interval in the boundary table.
        """
        begin = interval.begin
        end = interval.end
        if begin in self.boundary_table: 
            self.boundary_table[begin] += 1
        else:
            self.boundary_table[begin] = 1
        
        if end in self.boundary_table:
            self.boundary_table[end] += 1
        else:
            self.boundary_table[end] = 1
    
    def _remove_boundaries(self, interval):
        """
        Removes the boundaries of the interval from the boundary table.
        """
        begin = interval.begin
        end = interval.end
        if self.boundary_table[begin] == 1:
            del self.boundary_table[begin]
        else:
            self.boundary_table[begin] -= 1
        
        if self.boundary_table[end] == 1:
            del self.boundary_table[end]
        else:
            self.boundary_table[end] -= 1
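
        # Illustrative (assumption): after adding Interval(0, 5) and
        # Interval(3, 5) the boundary_table SortedDict holds {0: 1, 3: 1, 5: 2};
        # removing Interval(3, 5) deletes the key 3 and decrements 5 to 1.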
    
    def add(self, interval):
        """
        Adds an interval to the tree, if not already present.
        
        Completes in O(log n) time.
        """
        if interval in self: 
            return

        if interval.is_null():
            raise ValueError(
                "IntervalTree: Null Interval objects not allowed in IntervalTree:"
                " {0}".format(interval)
            )

        if not self.top_node:
            self.top_node = Node.from_interval(interval)
        else:
            self.top_node = self.top_node.add(interval)
        self.all_intervals.add(interval)
        self._add_boundaries(interval)
    append = add
    
    def addi(self, begin, end, data=None):
        """
        Shortcut for add(Interval(begin, end, data)).
        
        Completes in O(log n) time.
        """
        return self.add(Interval(begin, end, data))
    appendi = addi
    
    def update(self, intervals):
        """
        Given an iterable of intervals, add them to the tree.
        
        Completes in O(m*log(n+m)) time, where m = number of intervals to
        add.
        """
        for iv in intervals:
            self.add(iv)

    def extend(self, intervals):
        """
        Deprecated: Replaced by update().
        """
        warn("IntervalTree.extend() has been deprecated. Consider using update() instead", DeprecationWarning)
        self.update(intervals)

    def remove(self, interval):
        """
        Removes an interval from the tree, if present. If not, raises 
        ValueError.
        
        Completes in O(log n) time.
        """
        #self.verify()
        if interval not in self:
            #print(self.all_intervals)
            raise ValueError
        self.top_node = self.top_node.remove(interval)
        self.all_intervals.remove(interval)
        self._remove_boundaries(interval)
        #self.verify()
    
    def removei(self, begin, end, data=None):
        """
        Shortcut for remove(Interval(begin, end, data)).
        
        Completes in O(log n) time.
        """
        return self.remove(Interval(begin, end, data))
    
    def discard(self, interval):
        """
        Removes an interval from the tree, if present. If not, does 
        nothing.
        
        Completes in O(log n) time.
        """
        if interval not in self:
            return
        self.all_intervals.discard(interval)
        self.top_node = self.top_node.discard(interval)
        self._remove_boundaries(interval)
    
    def discardi(self, begin, end, data=None):
        """
        Shortcut for discard(Interval(begin, end, data)).
        
        Completes in O(log n) time.
        """
        return self.discard(Interval(begin, end, data))

    def difference(self, other):
        """
        Returns a new tree, comprising all intervals in self but not
        in other.
        """
        ivs = set()
        for iv in self:
            if iv not in other:
                ivs.add(iv)
        return IntervalTree(ivs)

    def difference_update(self, other):
        """
        Removes all intervals in other from self.
        """
        for iv in other:
            self.discard(iv)

    def union(self, other):
        """
        Returns a new tree, comprising all intervals from self
        and other.
        """
        return IntervalTree(set(self).union(other))

    def intersection(self, other):
        """
        Returns a new tree of all intervals common to both self and
        other.
        """
        ivs = set()
        shorter, longer = sorted([self, other], key=len)
        for iv in shorter:
            if iv in longer:
                ivs.add(iv)
        return IntervalTree(ivs)

    def intersection_update(self, other):
        """
        Removes intervals from self unless they also exist in other.
        """
        for iv in self:
            if iv not in other:
                self.remove(iv)

    def symmetric_difference(self, other):
        """
        Return a tree with elements only in self or other but not
        both.
        """
        if not isinstance(other, set): other = set(other)
        me = set(self)
        ivs = (me - other) | (other - me)
        return IntervalTree(ivs)

    def symmetric_difference_update(self, other):
        """
        Throws out all intervals except those only in self or other,
        not both.
        """
        other = set(other)
        for iv in self:
            if iv in other:
                self.remove(iv)
                other.remove(iv)
        self.update(other)

    def remove_overlap(self, begin, end=None):
        """
        Removes all intervals overlapping the given point or range.
        
        Completes in O((r+m)*log n) time, where:
          * n = size of the tree
          * m = number of matches
          * r = size of the search range (this is 1 for a point)
        """
        hitlist = self.search(begin, end)
        for iv in hitlist: 
            self.remove(iv)

    def remove_envelop(self, begin, end):
        """
        Removes all intervals completely enveloped in the given range.
        
        Completes in O((r+m)*log n) time, where:
          * n = size of the tree
          * m = number of matches
          * r = size of the search range (this is 1 for a point)
        """
        hitlist = self.search(begin, end, strict=True)
        for iv in hitlist:
            self.remove(iv)

    def chop(self, begin, end, datafunc=None):
        """
        Like remove_envelop(), but trims back Intervals hanging into
        the chopped area so that nothing overlaps.
        """
        insertions = set()
        begin_hits = [iv for iv in self[begin] if iv.begin < begin]
        end_hits = [iv for iv in self[end] if iv.end > end]

        if datafunc:
            for iv in begin_hits:
                insertions.add(Interval(iv.begin, begin, datafunc(iv, True)))
            for iv in end_hits:
                insertions.add(Interval(end, iv.end, datafunc(iv, False)))
        else:
            for iv in begin_hits:
                insertions.add(Interval(iv.begin, begin, iv.data))
            for iv in end_hits:
                insertions.add(Interval(end, iv.end, iv.data))

        self.remove_envelop(begin, end)
        self.difference_update(begin_hits)
        self.difference_update(end_hits)
        self.update(insertions)
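
    # A minimal usage sketch of chop() (illustrative values, not from the original source):
    # intervals hanging into the chopped range are trimmed back so nothing overlaps it.
    #
    #   t = IntervalTree([Interval(0, 10, 'a'), Interval(5, 15, 'b')])
    #   t.chop(3, 7)
    #   sorted(t)  # [Interval(0, 3, 'a'), Interval(7, 10, 'a'), Interval(7, 15, 'b')]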

    def slice(self, point, datafunc=None):
        """
        Split Intervals that overlap point into two new Intervals. If
        specified, uses datafunc(interval, islower=True/False) to
        set the data field of the new Intervals.
        :param point: where to slice
        :param datafunc(interval, islower): callable returning a new
        value for the interval's data field
        """
        hitlist = set(iv for iv in self[point] if iv.begin < point)
        insertions = set()
        if datafunc:
            for iv in hitlist:
                insertions.add(Interval(iv.begin, point, datafunc(iv, True)))
                insertions.add(Interval(point, iv.end, datafunc(iv, False)))
        else:
            for iv in hitlist:
                insertions.add(Interval(iv.begin, point, iv.data))
                insertions.add(Interval(point, iv.end, iv.data))
        self.difference_update(hitlist)
        self.update(insertions)
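
    # A minimal usage sketch of slice() (illustrative values, not from the original source):
    #
    #   t = IntervalTree([Interval(0, 10, 'a')])
    #   t.slice(4)
    #   sorted(t)  # [Interval(0, 4, 'a'), Interval(4, 10, 'a')]
    #
    #   # with datafunc, the second argument tells you which piece is the lower half
    #   t2 = IntervalTree([Interval(0, 10, 'a')])
    #   t2.slice(4, datafunc=lambda iv, islower: (iv.data, 'low' if islower else 'high'))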

    def clear(self):
        """
        Empties the tree.

        Completes in O(1) time.
        """
        self.__init__()

    def find_nested(self):
        """
        Returns a dictionary mapping parent intervals to sets of 
        intervals overlapped by and contained in the parent.
        
        Completes in O(n^2) time.
        :rtype: dict of [Interval, set of Interval]
        """
        result = {}
        
        def add_if_nested():
            if parent.contains_interval(child):
                if parent not in result:
                    result[parent] = set()
                result[parent].add(child)
                
        long_ivs = sorted(self.all_intervals, key=Interval.length, reverse=True)
        for i, parent in enumerate(long_ivs):
            for child in long_ivs[i + 1:]:
                add_if_nested()
        return result
    
    def overlaps(self, begin, end=None):
        """
        Returns whether some interval in the tree overlaps the given
        point or range.
        
        Completes in O(r*log n) time, where r is the size of the
        search range.
        :rtype: bool
        """
        if end is not None:
            return self.overlaps_range(begin, end)
        elif isinstance(begin, Number):
            return self.overlaps_point(begin)
        else:
            return self.overlaps_range(begin.begin, begin.end)
    
    def overlaps_point(self, p):
        """
        Returns whether some interval in the tree overlaps p.
        
        Completes in O(log n) time.
        :rtype: bool
        """
        if self.is_empty():
            return False
        return bool(self.top_node.contains_point(p))
    
    def overlaps_range(self, begin, end):
        """
        Returns whether some interval in the tree overlaps the given
        range. Returns False if given a null interval over which to
        test.
        
        Completes in O(r*log n) time, where r is the range length and n
        is the table size.
        :rtype: bool
        """
        if self.is_empty():
            return False
        elif begin >= end:
            return False
        elif self.overlaps_point(begin):
            return True
        return any(
            self.overlaps_point(bound) 
            for bound in self.boundary_table 
            if begin < bound < end
        )
    
    def split_overlaps(self):
        """
        Finds all intervals with overlapping ranges and splits them
        along the range boundaries.
        
        Completes in worst-case O(n^2*log n) time (many interval 
        boundaries are inside many intervals), best-case O(n*log n)
        time (small number of overlaps << n per interval).
        """
        if not self:
            return
        if len(self.boundary_table) == 2:
            return

        bounds = sorted(self.boundary_table)  # get bound locations

        new_ivs = set()
        for lbound, ubound in zip(bounds[:-1], bounds[1:]):
            for iv in self[lbound]:
                new_ivs.add(Interval(lbound, ubound, iv.data))

        self.__init__(new_ivs)
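
    # A minimal sketch of split_overlaps() (illustrative values, not from the original
    # source): every interval is cut at every boundary that falls inside it.
    #
    #   t = IntervalTree([Interval(0, 10, 'a'), Interval(5, 15, 'b')])
    #   t.split_overlaps()
    #   sorted(t)  # [Interval(0, 5, 'a'), Interval(5, 10, 'a'),
    #              #  Interval(5, 10, 'b'), Interval(10, 15, 'b')]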

    def merge_overlaps(self, data_reducer=None, data_initializer=None):
        """
        Finds all intervals with overlapping ranges and merges them
        into a single interval. If provided, uses data_reducer and
        data_initializer with similar semantics to Python's built-in
        reduce(reducer_func[, initializer]), as follows:

        If data_reducer is set to a function, combines the data
        fields of the Intervals with
            current_reduced_data = data_reducer(current_reduced_data, new_data)
        If data_reducer is None, the merged Interval's data
        field will be set to None, ignoring all the data fields
        of the merged Intervals.

        On encountering the first Interval to merge, if
        data_initializer is None (default), uses the first
        Interval's data field as the first value for
        current_reduced_data. If data_initializer is not None,
        current_reduced_data is set to a shallow copy of
        data_initializer created with
            copy.copy(data_initializer).

        Completes in O(n*log n) time.
        """
        if not self:
            return

        sorted_intervals = sorted(self.all_intervals)  # get sorted intervals
        merged = []
        # use mutable object to allow new_series() to modify it
        current_reduced = [None]
        higher = None  # iterating variable, which new_series() needs access to

        def new_series():
            if data_initializer is None:
                current_reduced[0] = higher.data
                merged.append(higher)
                return
            else:  # data_initializer is not None
                current_reduced[0] = copy(data_initializer)
                current_reduced[0] = data_reducer(current_reduced[0], higher.data)
                merged.append(Interval(higher.begin, higher.end, current_reduced[0]))

        for higher in sorted_intervals:
            if merged:  # series already begun
                lower = merged[-1]
                if higher.begin <= lower.end:  # should merge
                    upper_bound = max(lower.end, higher.end)
                    if data_reducer is not None:
                        current_reduced[0] = data_reducer(current_reduced[0], higher.data)
                    else:  # annihilate the data, since we don't know how to merge it
                        current_reduced[0] = None
                    merged[-1] = Interval(lower.begin, upper_bound, current_reduced[0])
                else:
                    new_series()
            else:  # not merged; is first of Intervals to merge
                new_series()

        self.__init__(merged)
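
    # A minimal sketch of merge_overlaps() with and without a data_reducer
    # (illustrative values, not from the original source):
    #
    #   t = IntervalTree([Interval(0, 5, 1), Interval(4, 10, 2), Interval(12, 15, 3)])
    #   t.merge_overlaps(data_reducer=lambda acc, new: acc + new)
    #   sorted(t)  # [Interval(0, 10, 3), Interval(12, 15, 3)]
    #
    #   # without a reducer the merged interval's data is annihilated to None
    #   t2 = IntervalTree([Interval(0, 5, 1), Interval(4, 10, 2)])
    #   t2.merge_overlaps()
    #   sorted(t2)  # [Interval(0, 10, None)]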

    def merge_equals(self, data_reducer=None, data_initializer=None):
        """
        Finds all intervals with equal ranges and merges them
        into a single interval. If provided, uses data_reducer and
        data_initializer with similar semantics to Python's built-in
        reduce(reducer_func[, initializer]), as follows:

        If data_reducer is set to a function, combines the data
        fields of the Intervals with
            current_reduced_data = data_reducer(current_reduced_data, new_data)
        If data_reducer is None, the merged Interval's data
        field will be set to None, ignoring all the data fields
        of the merged Intervals.

        On encountering the first Interval to merge, if
        data_initializer is None (default), uses the first
        Interval's data field as the first value for
        current_reduced_data. If data_initializer is not None,
        current_reduced_data is set to a shallow copy of
        data_initializer created with
            copy.copy(data_initializer).

        Completes in O(n*log n) time.
        """
        if not self:
            return

        sorted_intervals = sorted(self.all_intervals)  # get sorted intervals
        merged = []
        # use mutable object to allow new_series() to modify it
        current_reduced = [None]
        higher = None  # iterating variable, which new_series() needs access to

        def new_series():
            if data_initializer is None:
                current_reduced[0] = higher.data
                merged.append(higher)
                return
            else:  # data_initializer is not None
                current_reduced[0] = copy(data_initializer)
                current_reduced[0] = data_reducer(current_reduced[0], higher.data)
                merged.append(Interval(higher.begin, higher.end, current_reduced[0]))

        for higher in sorted_intervals:
            if merged:  # series already begun
                lower = merged[-1]
                if higher.range_matches(lower):  # should merge
                    upper_bound = max(lower.end, higher.end)
                    if data_reducer is not None:
                        current_reduced[0] = data_reducer(current_reduced[0], higher.data)
                    else:  # annihilate the data, since we don't know how to merge it
                        current_reduced[0] = None
                    merged[-1] = Interval(lower.begin, upper_bound, current_reduced[0])
                else:
                    new_series()
            else:  # not merged; is first of Intervals to merge
                new_series()

        self.__init__(merged)

    def items(self):
        """
        Constructs and returns a set of all intervals in the tree. 
        
        Completes in O(n) time.
        :rtype: set of Interval
        """
        return set(self.all_intervals)
    
    def is_empty(self):
        """
        Returns whether the tree is empty.
        
        Completes in O(1) time.
        :rtype: bool
        """
        return 0 == len(self)

    def search(self, begin, end=None, strict=False):
        """
        Returns a set of all intervals overlapping the given range. Or,
        if strict is True, returns the set of all intervals fully
        contained in the range [begin, end].
        
        Completes in O(m + k*log n) time, where:
          * n = size of the tree
          * m = number of matches
          * k = size of the search range (this is 1 for a point)
        :rtype: set of Interval
        """
        root = self.top_node
        if not root:
            return set()
        if end is None:
            try:
                iv = begin
                return self.search(iv.begin, iv.end, strict=strict)
            except:
                return root.search_point(begin, set())
        elif begin >= end:
            return set()
        else:
            result = root.search_point(begin, set())

            boundary_table = self.boundary_table
            bound_begin = boundary_table.bisect_left(begin)
            bound_end = boundary_table.bisect_left(end)  # exclude final end bound
            result.update(root.search_overlap(
                # slice notation is slightly slower
                boundary_table.iloc[index] for index in range(bound_begin, bound_end)
            ))

            # TODO: improve strict search to use node info instead of less-efficient filtering
            if strict:
                result = set(
                    iv for iv in result
                    if iv.begin >= begin and iv.end <= end
                )
            return result
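
    # A minimal sketch of search() with and without strict containment
    # (illustrative values, not from the original source):
    #
    #   t = IntervalTree([Interval(1, 4, 'a'), Interval(3, 8, 'b')])
    #   t.search(0, 5)               # {Interval(1, 4, 'a'), Interval(3, 8, 'b')}
    #   t.search(0, 5, strict=True)  # {Interval(1, 4, 'a')} -- only fully contained hits
    #   t.search(2)                  # {Interval(1, 4, 'a')} -- point query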
    
    def begin(self):
        """
        Returns the lower bound of the first interval in the tree.
        
        Completes in O(n) time.
        """
        if not self.boundary_table:
            return 0
        return self.boundary_table.iloc[0]
    
    def end(self):
        """
        Returns the upper bound of the last interval in the tree.
        
        Completes in O(n) time.
        """
        if not self.boundary_table:
            return 0
        return self.boundary_table.iloc[-1]

    def range(self):
        """
        Returns a minimum-spanning Interval that encloses all the
        members of this IntervalTree. If the tree is empty, returns
        null Interval.
        :rtype: Interval
        """
        return Interval(self.begin(), self.end())

    def span(self):
        """
        Returns the length of the minimum-spanning Interval that
        encloses all the members of this IntervalTree. If the tree
        is empty, return 0.
        """
        if not self:
            return 0
        return self.end() - self.begin()

    def print_structure(self, tostring=False):
        """
        ## FOR DEBUGGING ONLY ##
        Pretty-prints the structure of the tree. 
        If tostring is true, prints nothing and returns a string.
        :rtype: None or str
        """
        if self.top_node:
            return self.top_node.print_structure(tostring=tostring)
        else:
            result = "<empty IntervalTree>"
            if not tostring:
                print(result)
            else:
                return result
        
    def verify(self):
        """
        ## FOR DEBUGGING ONLY ##
        Checks the table to ensure that the invariants are held.
        """
        if self.all_intervals:
            ## top_node.all_children() == self.all_intervals
            try:
                assert self.top_node.all_children() == self.all_intervals
            except AssertionError as e:
                print(
                    'Error: the tree and the membership set are out of sync!'
                )
                tivs = set(self.top_node.all_children())
                print('top_node.all_children() - all_intervals:')
                try:
                    pprint
                except NameError:
                    from pprint import pprint
                pprint(tivs - self.all_intervals)
                print('all_intervals - top_node.all_children():')
                pprint(self.all_intervals - tivs)
                raise e

            ## All members are Intervals
            for iv in self:
                assert isinstance(iv, Interval), (
                    "Error: Only Interval objects allowed in IntervalTree:"
                    " {0}".format(iv)
                )

            ## No null intervals
            for iv in self:
                assert not iv.is_null(), (
                    "Error: Null Interval objects not allowed in IntervalTree:"
                    " {0}".format(iv)
                )

            ## Reconstruct boundary_table
            bound_check = {}
            for iv in self:
                if iv.begin in bound_check:
                    bound_check[iv.begin] += 1
                else:
                    bound_check[iv.begin] = 1
                if iv.end in bound_check:
                    bound_check[iv.end] += 1
                else:
                    bound_check[iv.end] = 1

            ## Reconstructed boundary table (bound_check) ==? boundary_table
            assert set(self.boundary_table.keys()) == set(bound_check.keys()),\
                'Error: boundary_table is out of sync with ' \
                'the intervals in the tree!'

            # For efficiency reasons this should be iteritems in Py2, but we
            # don't care much for efficiency in debug methods anyway.
            for key, val in self.boundary_table.items():
                assert bound_check[key] == val, \
                    'Error: boundary_table[{0}] should be {1},' \
                    ' but is {2}!'.format(
                        key, bound_check[key], val)

            ## Internal tree structure
            self.top_node.verify(set())
        else:
            ## Verify empty tree
            assert not self.boundary_table, \
                "Error: boundary table should be empty!"
            assert self.top_node is None, \
                "Error: top_node isn't None!"

    def score(self, full_report=False):
        """
        Returns a number between 0 and 1, indicating how suboptimal the tree
        is. The lower, the better. Roughly, this number represents the
        fraction of flawed Intervals in the tree.
        :rtype: float
        """
        if len(self) <= 2:
            return 0.0

        n = len(self)
        m = self.top_node.count_nodes()

        def s_center_score():
            """
            Returns a normalized score, indicating roughly how many times
            intervals share s_center with other intervals. Output is full-scale
            from 0 to 1.
            :rtype: float
            """
            raw = n - m
            maximum = n - 1
            return raw / float(maximum)

        report = {
            "depth": self.top_node.depth_score(n, m),
            "s_center": s_center_score(),
        }
        cumulative = max(report.values())
        report["_cumulative"] = cumulative
        if full_report:
            return report
        return cumulative


    def __getitem__(self, index):
        """
        Returns a set of all intervals overlapping the given index or 
        slice.
        
        Completes in O(k * log(n) + m) time, where:
          * n = size of the tree
          * m = number of matches
          * k = size of the search range (this is 1 for a point)
        :rtype: set of Interval
        """
        try:
            start, stop = index.start, index.stop
            if start is None:
                start = self.begin()
                if stop is None:
                    return set(self)
            if stop is None:
                stop = self.end()
            return self.search(start, stop)
        except AttributeError:
            return self.search(index)
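
    # Index and slice access are thin wrappers over search() (illustrative values,
    # not from the original source):
    #
    #   t = IntervalTree([Interval(1, 4, 'a'), Interval(3, 8, 'b')])
    #   t[2]     # point query: {Interval(1, 4, 'a')}
    #   t[2:6]   # range query: {Interval(1, 4, 'a'), Interval(3, 8, 'b')}
    #   t[:]     # the whole tree as a set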
    
    def __setitem__(self, index, value):
        """
        Adds a new interval to the tree. A shortcut for
        add(Interval(index.start, index.stop, value)).
        
        If an identical Interval object with equal range and data 
        already exists, does nothing.
        
        Completes in O(log n) time.
        """
        self.addi(index.start, index.stop, value)

    def __delitem__(self, point):
        """
        Delete all items overlapping point.
        """
        self.remove_overlap(point)

    def __contains__(self, item):
        """
        Returns whether item exists as an Interval in the tree.
        This method only returns True for exact matches; for
        overlaps, see the overlaps() method.
        
        Completes in O(1) time.
        :rtype: bool
        """
        # Removed point-checking code; it might trick the user into
        # thinking that this is O(1), which point-checking isn't.
        #if isinstance(item, Interval):
        return item in self.all_intervals
        #else:
        #    return self.contains_point(item)
    
    def containsi(self, begin, end, data=None):
        """
        Shortcut for (Interval(begin, end, data) in tree).
        
        Completes in O(1) time.
        :rtype: bool
        """
        return Interval(begin, end, data) in self
    
    def __iter__(self):
        """
        Returns an iterator over all the intervals in the tree.
        
        Completes in O(1) time.
        :rtype: collections.Iterable[Interval]
        """
        return self.all_intervals.__iter__()
    iter = __iter__
    
    def __len__(self):
        """
        Returns how many intervals are in the tree.
        
        Completes in O(1) time.
        :rtype: int
        """
        return len(self.all_intervals)
    
    def __eq__(self, other):
        """
        Whether two IntervalTrees are equal.
        
        Completes in O(n) time if sizes are equal; O(1) time otherwise.
        :rtype: bool
        """
        return (
            isinstance(other, IntervalTree) and 
            self.all_intervals == other.all_intervals
        )
    
    def __repr__(self):
        """
        :rtype: str
        """
        ivs = sorted(self)
        if not ivs:
            return "IntervalTree()"
        else:
            return "IntervalTree({0})".format(ivs)

    __str__ = __repr__

    def __reduce__(self):
        """
        For pickle-ing.
        :rtype: tuple
        """
        return IntervalTree, (sorted(self.all_intervals),)
Example #50
0
class Alert(object):
    """description of class"""
    def __init__(self, config):
        self.config_=config
        self.statsStateDict_=SortedDict()
        self.monitorStateDict_=SortedDict()
        return
    def displayAlerts(self):
        for time in self.statsStateDict_:
            if not self.statsStateDict_[time].printed_:
                print(f'stats for last {self.config_.statsTime_} secs for {time} is {self.statsStateDict_[time]}')
        for time in self.monitorStateDict_:
            if not self.monitorStateDict_[time].printed_:
                print(f'{self.monitorStateDict_[time]}')
        
        return
    def updateAlertState(self,logState):
        if not logState.keys():
            return

        lastInsertedTime = logState.keys()[-1]
        if lastInsertedTime in self.statsStateDict_.keys():
            logline = logState.getElem(int(lastInsertedTime))[-1]#last inserted element in the list
            self.statsStateDict_[lastInsertedTime].increment_total_hits()
            self.statsStateDict_[lastInsertedTime].increment_hits_per_host(logline.getRemoteHost())
            self.statsStateDict_[lastInsertedTime].increment_hits_per_user(logline.getUser())
            self.statsStateDict_[lastInsertedTime].increment_hits_per_request(logline.getRequest())
            self.statsStateDict_[lastInsertedTime].increment_hits_per_status(logline.getStatus())
            self.statsStateDict_[lastInsertedTime].increment_hits_per_section(logline.getRequest())
        else:
            statsTimeBack = logState.getLowerBound(int(lastInsertedTime)-int(self.config_.statsTime_))
            diffStartTime = statsTimeBack[0]
            if int(lastInsertedTime)-int(diffStartTime) == self.config_.statsTime_:
                count=0
                currentStatsState=StatsState()
                for time in statsTimeBack:
                    temp = logState.getElem(time)
                    count += len(temp)
                    for logline in temp:
                        currentStatsState.increment_hits_per_host(logline.getRemoteHost())
                        currentStatsState.increment_hits_per_user(logline.getUser())
                        currentStatsState.increment_hits_per_request(logline.getRequest())
                        currentStatsState.increment_hits_per_status(logline.getStatus())
                        currentStatsState.increment_hits_per_section(logline.getRequest())
                currentStatsState.set_total_hits(count)
                self.statsStateDict_[lastInsertedTime] = currentStatsState  # plain item assignment rather than the private _setitem helper
   
        if int(lastInsertedTime) not in self.monitorStateDict_.keys():
            MonitorTimeBack = logState.getLowerBound(int(lastInsertedTime)-int(self.config_.monitorTime_))
            diffStartTime = MonitorTimeBack[0]
            if int(lastInsertedTime) - int(diffStartTime) == int(self.config_.monitorTime_):
                count =0
                currentMonitorState=MonitorState(int(lastInsertedTime),count,self.config_)
                for time in MonitorTimeBack:
                    temp = logState.getElem(time)
                    count += len(temp)
                currentMonitorState.totalHits_=count
                self.monitorStateDict_[lastInsertedTime]=currentMonitorState
        else:
            self.monitorStateDict_.get(int(lastInsertedTime)).totalHits_ +=1
            self.monitorStateDict_.get(int(lastInsertedTime)).printed_=False
Example #51
0
# NOTE: the top of this snippet is missing from the source; the class header, __init__,
# __lt__ and the opening of __eq__ below are a hedged reconstruction inferred from the
# SortedDict usage further down (Range(fr, to) keys that are queried with plain numbers).
class Range(object):
    def __init__(self, fr, to):
        self.fr = fr
        self.to = to
    def __lt__(self, t):
        # assumed ordering: a range sorts before any point/range at or beyond its end
        if isinstance(t, Range):
            return self.to <= t.fr
        return self.to <= t
    def __eq__(self, t):
        if isinstance(t, (int, float)):
            # a bare number compares equal to the range containing it
            return self.fr <= t and t < self.to
        return self.fr == t.fr and self.to == t.to
    def __ne__(self, t):
        return not self.__eq__(t)
    def __hash__(self):
        return hash(str(self))
    def __str__(self):
        return "[{}-{}]".format(self.fr, self.to)

from sortedcontainers import SortedDict

d = SortedDict()
d[Range(50,60)] = None
d[Range(40,50)] = None
d[Range(10,20)] = None
d[Range(80,90)] = None
d[Range(110,120)] = None
for k in d.keys():
    print(k)
print("==============================")
print("85", d.iloc[d.index(85)])
print("40", d.iloc[d.index(40)])
print("49", d.iloc[d.index(49)])
print("50", d.iloc[d.index(50)])
print("59", d.iloc[d.index(59)])
try:
    d.index(60)
except ValueError:
    print("60 is not in d")
try:
    d[60]
except KeyError:
    # plain __getitem__ is a hash-based dict lookup on the key object, so a bare number
    # raises KeyError even though d.index(...) above can locate the Range containing a point
    print("d[60] raises KeyError")
Example #52
0
class KeyedRegion:
    """
    KeyedRegion keeps a mapping between stack offsets and all objects covering that offset. It assumes no variable in
    this region overlap with another variable in this region.

    Registers and function frames can all be viewed as a keyed region.
    """

    __slots__ = ('_storage', '_object_mapping', '_phi_node_contains')

    def __init__(self, tree=None, phi_node_contains=None):
        self._storage = SortedDict() if tree is None else tree
        self._object_mapping = weakref.WeakValueDictionary()
        self._phi_node_contains = phi_node_contains

    def __getstate__(self):
        return self._storage, dict(
            self._object_mapping), self._phi_node_contains

    def __setstate__(self, s):
        self._storage, om, self._phi_node_contains = s
        self._object_mapping = weakref.WeakValueDictionary(om)

    def _get_container(self, offset):
        try:
            base_offset = next(
                self._storage.irange(maximum=offset, reverse=True))
        except StopIteration:
            return offset, None
        else:
            container = self._storage[base_offset]
            if container.includes(offset):
                return base_offset, container
            return offset, None
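
    # _get_container() is a "floor lookup": take the greatest stored base offset that is
    # <= the query offset, then check whether that region actually covers the query.
    # A minimal sketch of the same SortedDict.irange() trick (illustrative only, not
    # taken from the original source):
    #
    #   from sortedcontainers import SortedDict
    #   sd = SortedDict({0: 'region@0', 8: 'region@8', 16: 'region@16'})
    #   next(sd.irange(maximum=10, reverse=True))  # -> 8, the candidate base offset for 10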

    def __contains__(self, offset):
        """
        Test if there is at least one variable covering the given offset.

        :param offset:
        :return:
        """

        if type(offset) is not int:
            raise TypeError("KeyedRegion only accepts concrete offsets.")

        return self._get_container(offset)[1] is not None

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        return iter(self._storage.values())

    def __eq__(self, other):
        if set(self._storage.keys()) != set(other._storage.keys()):
            return False

        for k, v in self._storage.items():
            if v != other._storage[k]:
                return False

        return True

    def copy(self):
        if not self._storage:
            return KeyedRegion(phi_node_contains=self._phi_node_contains)

        kr = KeyedRegion(phi_node_contains=self._phi_node_contains)
        for key, ro in self._storage.items():
            kr._storage[key] = ro.copy()
        kr._object_mapping = self._object_mapping.copy()
        return kr

    def merge(self, other, replacements=None):
        """
        Merge another KeyedRegion into this KeyedRegion.

        :param KeyedRegion other: The other instance to merge with.
        :return: None
        """

        # TODO: is the current solution not optimal enough?
        for _, item in other._storage.items():  # type: RegionObject
            for so in item.stored_objects:  # type: StoredObject
                if replacements and so.obj in replacements:
                    so = StoredObject(so.start, replacements[so.obj], so.size)
                self._object_mapping[so.obj_id] = so
                self.__store(so, overwrite=False)

        return self

    def merge_to_top(self, other, replacements=None, top=None):
        """
        Merge another KeyedRegion into this KeyedRegion, but mark all variables with different values as TOP.

        :param other:   The other instance to merge with.
        :param replacements:
        :return:        self
        """

        for _, item in other._storage.items():  # type: RegionObject
            for so in item.stored_objects:  # type: StoredObject
                if replacements and so.obj in replacements:
                    so = StoredObject(so.start, replacements[so.obj], so.size)
                self._object_mapping[so.obj_id] = so
                self.__store(so, overwrite=False, merge_to_top=True, top=top)

        return self

    def replace(self, replacements):
        """
        Replace variables with other variables.

        :param dict replacements:   A dict of variable replacements.
        :return:                    self
        """

        for old_var, new_var in replacements.items():
            old_var_id = id(old_var)
            if old_var_id in self._object_mapping:
                # FIXME: we need to check if old_var still exists in the storage
                old_so = self._object_mapping[old_var_id]  # type: StoredObject
                self._store(old_so.start, new_var, old_so.size, overwrite=True)

        return self

    def dbg_repr(self):
        """
        Get a debugging representation of this keyed region.
        :return: A string of debugging output.
        """
        keys = self._storage.keys()
        offset_to_vars = {}

        for key in sorted(keys):
            ro = self._storage[key]
            variables = [obj.obj for obj in ro.stored_objects]
            offset_to_vars[key] = variables

        s = []
        for offset, variables in offset_to_vars.items():
            s.append("Offset %#x: %s" % (offset, variables))
        return "\n".join(s)

    def add_variable(self, start, variable):
        """
        Add a variable to this region at the given offset.

        :param int start:
        :param SimVariable variable:
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.add_object(start, variable, size)

    def add_object(self, start, obj, object_size):
        """
        Add/Store an object to this region at the given offset.

        :param start:
        :param obj:
        :param int object_size: Size of the object
        :return:
        """

        self._store(start, obj, object_size, overwrite=False)

    def set_variable(self, start, variable):
        """
        Add a variable to this region at the given offset, and remove all other variables that are fully covered by
        this variable.

        :param int start:
        :param SimVariable variable:
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.set_object(start, variable, size)

    def set_object(self, start, obj, object_size):
        """
        Add an object to this region at the given offset, and remove all other objects that are fully covered by this
        object.

        :param start:
        :param obj:
        :param object_size:
        :return:
        """

        self._store(start, obj, object_size, overwrite=True)

    def get_base_addr(self, addr):
        """
        Get the base offset (the key we are using to index objects covering the given offset) of a specific offset.

        :param int addr:
        :return:
        :rtype:  int or None
        """

        base_addr, container = self._get_container(addr)
        if container is None:
            return None
        else:
            return base_addr

    def get_variables_by_offset(self, start):
        """
        Find variables covering the given region offset.

        :param int start:
        :return: A list of stack variables.
        :rtype:  set
        """

        _, container = self._get_container(start)
        if container is None:
            return []
        else:
            return container.internal_objects

    def get_objects_by_offset(self, start):
        """
        Find objects covering the given region offset.

        :param start:
        :return:
        """

        _, container = self._get_container(start)
        if container is None:
            return set()
        else:
            return container.internal_objects

    #
    # Private methods
    #

    def _store(self, start, obj, size, overwrite=False):
        """
        Store a variable into the storage.

        :param int start: The beginning address of the variable.
        :param obj: The object to store.
        :param int size: Size of the object to store.
        :param bool overwrite: Whether existing objects should be overwritten or not.
        :return: None
        """

        stored_object = StoredObject(start, obj, size)
        self._object_mapping[stored_object.obj_id] = stored_object
        self.__store(stored_object, overwrite=overwrite)

    def __store(self,
                stored_object,
                overwrite=False,
                merge_to_top=False,
                top=None):
        """
        Store a variable into the storage.

        :param StoredObject stored_object: The descriptor describing start address and the variable.
        :param bool overwrite:  Whether existing objects should be overwritten or not. True to make a strong update,
                                False to make a weak update.
        :return: None
        """

        start = stored_object.start
        object_size = stored_object.size
        end = start + object_size

        # region items in the middle
        overlapping_items = list(self._storage.irange(start, end - 1))

        # is there a region item that begins before the start and overlaps with this variable?
        floor_key, floor_item = self._get_container(start)
        if floor_item is not None and floor_key not in overlapping_items:
            # insert it into the beginning
            overlapping_items.insert(0, floor_key)

        # scan through the entire list of region items, split existing regions and insert new regions as needed
        to_update = {start: RegionObject(start, object_size, {stored_object})}
        last_end = start

        for floor_key in overlapping_items:
            item = self._storage[floor_key]
            if item.start < start:
                # we need to break this item into two
                a, b = item.split(start)
                if overwrite:
                    b.set_object(stored_object)
                else:
                    self._add_object_with_check(b,
                                                stored_object,
                                                merge_to_top=merge_to_top,
                                                top=top)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.end
            elif item.start > last_end:
                # there is a gap between the last item and the current item
                # fill in the gap
                new_item = RegionObject(last_end, item.start - last_end,
                                        {stored_object})
                to_update[new_item.start] = new_item
                last_end = new_item.end
            elif item.end > end:
                # we need to split this item into two
                a, b = item.split(end)
                if overwrite:
                    a.set_object(stored_object)
                else:
                    self._add_object_with_check(a,
                                                stored_object,
                                                merge_to_top=merge_to_top,
                                                top=top)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.end
            else:
                if overwrite:
                    item.set_object(stored_object)
                else:
                    self._add_object_with_check(item,
                                                stored_object,
                                                merge_to_top=merge_to_top,
                                                top=top)
                to_update[item.start] = item

        self._storage.update(to_update)

    def _is_overlapping(self, start, variable):

        if variable.size is not None:
            # make sure this variable does not overlap with any other variable
            end = start + variable.size
            try:
                prev_offset = next(
                    self._storage.irange(maximum=end - 1, reverse=True))
            except StopIteration:
                prev_offset = None

            if prev_offset is not None:
                if start <= prev_offset < end:
                    return True
                prev_item = self._storage[prev_offset][0]
                prev_item_size = prev_item.size if prev_item.size is not None else 1
                if start < prev_offset + prev_item_size < end:
                    return True
        else:
            try:
                prev_offset = next(
                    self._storage.irange(maximum=start, reverse=True))
            except StopIteration:
                prev_offset = None

            if prev_offset is not None:
                prev_item = self._storage[prev_offset][0]
                prev_item_size = prev_item.size if prev_item.size is not None else 1
                if prev_offset <= start < prev_offset + prev_item_size:
                    return True

        return False

    def _add_object_with_check(self,
                               item,
                               stored_object,
                               merge_to_top=False,
                               top=None):
        if len({stored_object.obj} | item.internal_objects) > 1:
            if merge_to_top:
                item.set_object(
                    StoredObject(stored_object.start, top, stored_object.size))
                return

            if self._phi_node_contains is not None:
                # check if `item` is a phi node that contains stored_object.obj
                for so in item.internal_objects:
                    if self._phi_node_contains(so, stored_object.obj):
                        # yes! so we want to skip this object
                        return
                # check if `stored_object.obj` is a phi node that contains item.internal_objects
                if all(
                        self._phi_node_contains(stored_object.obj, o)
                        for o in item.internal_objects):
                    # yes!
                    item.set_object(stored_object)
                    return

            # l.warning("Overlapping objects %s.", str({stored_object.obj} | item.internal_objects))
            # import ipdb; ipdb.set_trace()
        item.add_object(stored_object)
Example #53
0
def resolve_conflicts(pfam_hit_dict,minDomSize = 9,verbose=False):
    '''
    :param pfam_hit_dict: dictionary of hits for the gene in the following format
    hit start,hit end : int
    hit id : str
    score, model coverage percent : float
    {(hit start,hit end):('hit id',score,model coverage percent)}
    :param minDomSize: int, the minimum window size that will be considered a domain
    :return:
    a sorted dictionary with the position of the hit as the keys and ('hit id',score,model coverage percent)
    '''
    # initialize output
    gene_hits = SortedDict()
    redoFlag = True
    while redoFlag:
        if verbose: print("Sorting through intervals", pfam_hit_dict)
        redoFlag = False
        intervals_scores = [(key,value[1]) for key,value in pfam_hit_dict.items()]
        # sort intervals from pfam hits by score and place the highest score first
        intervals_scores.sort(key=itemgetter(1),reverse=True)
        # initialize intersect tree for quick overlap search
        intersectTree = IntervalTree()
        #add the intervals with the highest scores first
        for (interval,score) in intervals_scores:
            intervalStart = interval[0]
            intervalEnd = interval[1]
            intervalLength = intervalEnd-intervalStart+1
            # if the interval is less than the minimum domain size don't bother
            if intervalLength > minDomSize:
                intersectingIntervals = [(x.start,x.end) for x in intersectTree.find(intervalStart,intervalEnd)]
                overLapFlag = False
                # for every interval that you're adding resolve the overlapping intervals
                while len(intersectingIntervals) > 0 and intervalLength > 1:

                    start,end = intersectingIntervals[0]

                    # interval completely covers existing coverage, break up into two intervals and redo the process
                    if (intervalStart < start and intervalEnd > end):
                        if verbose: print("Split Interval", interval,intersectingIntervals, pfam_hit_dict[interval])
                        left_scale = calculate_window((intervalStart,start-1))/intervalLength
                        right_scale = calculate_window((end+1,intervalEnd))/intervalLength
                        pfam_hit_dict[(intervalStart,start-1)] = (pfam_hit_dict[interval][0],
                                                                  pfam_hit_dict[interval][1],
                                                                  pfam_hit_dict[interval][2] * left_scale)
                        pfam_hit_dict[(end+1,intervalEnd)] = (pfam_hit_dict[interval][0],
                                                              pfam_hit_dict[interval][1],
                                                              pfam_hit_dict[interval][2] * right_scale)
                        # delete original hit and iterate
                        del pfam_hit_dict[interval]
                        redoFlag = True
                        break
                    else:
                        #completely in the interval
                        if (intervalStart >= start and intervalEnd <= end):
                            #if completely overlapping then ignore since we already sorted by score
                            overLapFlag = True
                            break
                        #intersection covers the left hand side of the interval
                        elif intervalStart >= start:
                            intervalStart = end + 1
                        #intersection covers the right hand side of the interval
                        elif intervalEnd <= end:
                            intervalEnd = start - 1
                            # recalculate the interval length and see if there are still intersecting intervals
                        intervalLength = intervalEnd-intervalStart+1
                        intersectingIntervals = [(x.start,x.end) for x in intersectTree.find(intervalStart,intervalEnd)]

                if redoFlag:
                    if verbose: print("Exiting For Loop to Reinitialize",pfam_hit_dict)
                    break
                # if loop did not break because of an overlap add the annotation after resolving overlap,
                # check for minimum length after you merge intervals
                elif not overLapFlag and intervalLength > minDomSize:
                    if verbose: print("Adding Hit",(intervalStart,intervalEnd),pfam_hit_dict[interval][0])
                    # scale the hitCoverage based on the reduction this works since interval is a tuple and isn't mutated
                    hitCoverage = pfam_hit_dict[interval][2]*(intervalLength/(interval[1]-interval[0]+1.))
                    gene_hits[(intervalStart,intervalEnd)] = (pfam_hit_dict[interval][0],
                                                              pfam_hit_dict[interval][1],
                                                              hitCoverage)
                    intersectTree.add_interval(Interval(float(intervalStart),intervalEnd))
    if verbose: print("Merging Hits")
    # Merge Windows Right Next to one another that have the same pFam ID,
    # redoFlag: need to restart the process after a successful merge
    redoFlag = True
    while redoFlag:
        for idx in range(len(gene_hits)-1):
            left_hit = gene_hits.keys()[idx]
            right_hit = gene_hits.keys()[idx+1]
            left_window_size = calculate_window(left_hit)
            right_window_size = calculate_window(right_hit)
            merged_window_size = calculate_window((left_hit[0],right_hit[1]))
            new_coverage = (gene_hits[left_hit][2] + gene_hits[right_hit][2])*\
                           (left_window_size+ right_window_size)/merged_window_size
            # Will merge a hit under the following conditions:
            # 1. Gap between the two hits is less than the minimum domain
            # 2. Cumulative coverage of the two hits is less than 1 (this avoids merging repeats together)
            if right_hit[0]-left_hit[1] < minDomSize and gene_hits[left_hit][0] == gene_hits[right_hit][0] \
                    and new_coverage < 1:
                gene_hits[(left_hit[0],right_hit[1])] = (gene_hits[left_hit][0],
                                                         left_window_size/merged_window_size * gene_hits[left_hit][1] +
                                                         right_window_size/merged_window_size * gene_hits[right_hit][1],
                                                         new_coverage)
                redoFlag = True
                del gene_hits[left_hit]
                del gene_hits[right_hit]
                if verbose: print("Merged", left_hit,right_hit)
                break
        else:
            redoFlag = False
    if verbose: print("Deleting Domains Under Minimum Domain Size")
    # Finally check if any of the domains are less than the minimum domain size
    keysToDelete = [coordinates for coordinates in gene_hits.keys() if calculate_window(coordinates) < minDomSize]
    for key in keysToDelete:
        del gene_hits[key]
        if verbose: print("Deleting",key)
    if verbose: print("Final Annotation", gene_hits)
    return gene_hits
Example #54
0
import string

from sortedcontainers import SortedDict


def test_keys():
    mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)]
    temp = SortedDict(mapping)
    assert list(temp.keys()) == [key for key, pos in mapping]
Example #55
0
class KeyedRegion:
    """
    KeyedRegion keeps a mapping between stack offsets and all objects covering that offset. It assumes no variable in
    this region overlap with another variable in this region.

    Registers and function frames can all be viewed as a keyed region.
    """

    __slots__ = ('_storage', '_object_mapping', '_phi_node_contains' )

    def __init__(self, tree=None, phi_node_contains=None):
        self._storage = SortedDict() if tree is None else tree
        self._object_mapping = weakref.WeakValueDictionary()
        self._phi_node_contains = phi_node_contains

    def __getstate__(self):
        return self._storage, dict(self._object_mapping), self._phi_node_contains

    def __setstate__(self, s):
        self._storage, om, self._phi_node_contains = s
        self._object_mapping = weakref.WeakValueDictionary(om)

    def _get_container(self, offset):
        try:
            base_offset = next(self._storage.irange(maximum=offset, reverse=True))
        except StopIteration:
            return offset, None
        else:
            container = self._storage[base_offset]
            if container.includes(offset):
                return base_offset, container
            return offset, None

    def __contains__(self, offset):
        """
        Test if there is at least one variable covering the given offset.

        :param offset:
        :return:
        """

        if type(offset) is not int:
            raise TypeError("KeyedRegion only accepts concrete offsets.")

        return self._get_container(offset)[1] is not None

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        return iter(self._storage.values())

    def __eq__(self, other):
        if set(self._storage.keys()) != set(other._storage.keys()):
            return False

        for k, v in self._storage.items():
            if v != other._storage[k]:
                return False

        return True

    def copy(self):
        if not self._storage:
            return KeyedRegion(phi_node_contains=self._phi_node_contains)

        kr = KeyedRegion(phi_node_contains=self._phi_node_contains)
        for key, ro in self._storage.items():
            kr._storage[key] = ro.copy()
        kr._object_mapping = self._object_mapping.copy()
        return kr

    def merge(self, other, replacements=None):
        """
        Merge another KeyedRegion into this KeyedRegion.

        :param KeyedRegion other: The other instance to merge with.
        :return: None
        """

        # TODO: is the current solution not optimal enough?
        for _, item in other._storage.items():  # type: RegionObject
            for so in item.stored_objects:  # type: StoredObject
                if replacements and so.obj in replacements:
                    so = StoredObject(so.start, replacements[so.obj], so.size)
                self._object_mapping[so.obj_id] = so
                self.__store(so, overwrite=False)

        return self

    def replace(self, replacements):
        """
        Replace variables with other variables.

        :param dict replacements:   A dict of variable replacements.
        :return:                    self
        """

        for old_var, new_var in replacements.items():
            old_var_id = id(old_var)
            if old_var_id in self._object_mapping:
                # FIXME: we need to check if old_var still exists in the storage
                old_so = self._object_mapping[old_var_id]  # type: StoredObject
                self._store(old_so.start, new_var, old_so.size, overwrite=True)

        return self

    def dbg_repr(self):
        """
        Get a debugging representation of this keyed region.
        :return: A string of debugging output.
        """
        keys = self._storage.keys()
        offset_to_vars = { }

        for key in sorted(keys):
            ro = self._storage[key]
            variables = [ obj.obj for obj in ro.stored_objects ]
            offset_to_vars[key] = variables

        s = [ ]
        for offset, variables in offset_to_vars.items():
            s.append("Offset %#x: %s" % (offset, variables))
        return "\n".join(s)

    def add_variable(self, start, variable):
        """
        Add a variable to this region at the given offset.

        :param int start:
        :param SimVariable variable:
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.add_object(start, variable, size)

    def add_object(self, start, obj, object_size):
        """
        Add/Store an object to this region at the given offset.

        :param start:
        :param obj:
        :param int object_size: Size of the object
        :return:
        """

        self._store(start, obj, object_size, overwrite=False)

    def set_variable(self, start, variable):
        """
        Add a variable to this region at the given offset, and remove all other variables that are fully covered by
        this variable.

        :param int start:
        :param SimVariable variable:
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.set_object(start, variable, size)

    def set_object(self, start, obj, object_size):
        """
        Add an object to this region at the given offset, and remove all other objects that are fully covered by this
        object.

        :param start:
        :param obj:
        :param object_size:
        :return:
        """

        self._store(start, obj, object_size, overwrite=True)

    def get_base_addr(self, addr):
        """
        Get the base offset (the key we are using to index objects covering the given offset) of a specific offset.

        :param int addr:
        :return:
        :rtype:  int or None
        """

        base_addr, container = self._get_container(addr)
        if container is None:
            return None
        else:
            return base_addr

    def get_variables_by_offset(self, start):
        """
        Find variables covering the given region offset.

        :param int start:
        :return: A list of stack variables.
        :rtype:  set
        """

        _, container = self._get_container(start)
        if container is None:
            return []
        else:
            return container.internal_objects

    def get_objects_by_offset(self, start):
        """
        Find objects covering the given region offset.

        :param start:
        :return:
        """

        _, container = self._get_container(start)
        if container is None:
            return set()
        else:
            return container.internal_objects

    #
    # Private methods
    #

    def _store(self, start, obj, size, overwrite=False):
        """
        Store a variable into the storage.

        :param int start: The beginning address of the variable.
        :param obj: The object to store.
        :param int size: Size of the object to store.
        :param bool overwrite: Whether existing objects should be overwritten or not.
        :return: None
        """

        stored_object = StoredObject(start, obj, size)
        self._object_mapping[stored_object.obj_id] = stored_object
        self.__store(stored_object, overwrite=overwrite)

    def __store(self, stored_object, overwrite=False):
        """
        Store a variable into the storage.

        :param StoredObject stored_object: The descriptor describing start address and the variable.
        :param bool overwrite:  Whether existing objects should be overwritten or not. True to make a strong update,
                                False to make a weak update.
        :return: None
        """

        start = stored_object.start
        object_size = stored_object.size
        end = start + object_size

        # region items in the middle
        overlapping_items = list(self._storage.irange(start, end-1))

        # is there a region item that begins before the start and overlaps with this variable?
        floor_key, floor_item = self._get_container(start)
        if floor_item is not None and floor_key not in overlapping_items:
            # insert it into the beginning
            overlapping_items.insert(0, floor_key)

        # scan through the entire list of region items, split existing regions and insert new regions as needed
        to_update = {start: RegionObject(start, object_size, {stored_object})}
        last_end = start

        for floor_key in overlapping_items:
            item = self._storage[floor_key]
            if item.start < start:
                # we need to break this item into two
                a, b = item.split(start)
                if overwrite:
                    b.set_object(stored_object)
                else:
                    self._add_object_with_check(b, stored_object)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.end
            elif item.start > last_end:
                # there is a gap between the last item and the current item
                # fill in the gap
                new_item = RegionObject(last_end, item.start - last_end, {stored_object})
                to_update[new_item.start] = new_item
                last_end = new_item.end
            elif item.end > end:
                # we need to split this item into two
                a, b = item.split(end)
                if overwrite:
                    a.set_object(stored_object)
                else:
                    self._add_object_with_check(a, stored_object)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.end
            else:
                if overwrite:
                    item.set_object(stored_object)
                else:
                    self._add_object_with_check(item, stored_object)
                to_update[item.start] = item

        self._storage.update(to_update)

    def _is_overlapping(self, start, variable):

        if variable.size is not None:
            # make sure this variable does not overlap with any other variable
            end = start + variable.size
            try:
                prev_offset = next(self._storage.irange(maximum=end-1, reverse=True))
            except StopIteration:
                prev_offset = None

            if prev_offset is not None:
                if start <= prev_offset < end:
                    return True
                prev_item = self._storage[prev_offset][0]
                prev_item_size = prev_item.size if prev_item.size is not None else 1
                if start < prev_offset + prev_item_size < end:
                    return True
        else:
            try:
                prev_offset = next(self._storage.irange(maximum=start, reverse=True))
            except StopIteration:
                prev_offset = None

            if prev_offset is not None:
                prev_item = self._storage[prev_offset][0]
                prev_item_size = prev_item.size if prev_item.size is not None else 1
                if prev_offset <= start < prev_offset + prev_item_size:
                    return True

        return False

    def _add_object_with_check(self, item, stored_object):
        if len({stored_object.obj} | item.internal_objects) > 1:
            if self._phi_node_contains is not None:
                # check if `item` is a phi node that contains stored_object.obj
                for so in item.internal_objects:
                    if self._phi_node_contains(so, stored_object.obj):
                        # yes! so we want to skip this object
                        return
                # check if `stored_object.obj` is a phi node that contains item.internal_objects
                if all(self._phi_node_contains(stored_object.obj, o) for o in item.internal_objects):
                    # yes!
                    item.set_object(stored_object)
                    return

            l.warning("Overlapping objects %s.", str({stored_object.obj} | item.internal_objects))
            # import ipdb; ipdb.set_trace()
        item.add_object(stored_object)
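
A minimal usage sketch (not part of the original example) of how the weak/strong update distinction in __store plays out; region, var_a and var_b are assumed names for an instance of the class above and two arbitrary objects:

# Hypothetical usage; `region`, `var_a`, and `var_b` are assumed names, not from the source.
region._store(0x10, var_a, 4, overwrite=False)   # weak update: var_a is recorded at offset 0x10
region._store(0x10, var_b, 4, overwrite=False)   # weak update: offset 0x10 now holds both objects
region._store(0x10, var_b, 4, overwrite=True)    # strong update: earlier objects at 0x10 are replaced
print(region.get_objects_by_offset(0x12))        # a set containing only var_b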
Example #56
0
class RegularTreeGossip(RegularTree):

	def __init__(self, degree = None, spreading_time = None):
		super(RegularTreeGossip, self).__init__(degree, spreading_time)
		self.adversary = -1
		self.adversary_timestamps = SortedDict()
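		# SortedDict keeps the observation times in sorted order, so the earliest report to the adversary comes first when iterating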

		self.add_node(self.adversary, infected = False)
		

	def add_edge(self, u, v):
		super(RegularTree, self).add_edge(u, v)
		super(RegularTree, self).add_edge(u, self.adversary)
		super(RegularTree, self).add_edge(v, self.adversary)

	def add_node(self, u, attr_dict = None, **attr):
		super(RegularTree, self).add_node(u, attr)
		if not (u == self.adversary):
			super(RegularTree, self).add_edge(u, self.adversary)
		self.max_node = max(self.nodes())

	def generate_timestamp_dict(self):
		''' Creates a dict with nodes as keys and timestamps as values '''
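		# Illustration (assumed values): adversary_timestamps == {1: [3], 2: [5, 7]}
		# would yield {3: 1, 5: 2, 7: 2}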
		timestamp_dict = {}
		for key in self.adversary_timestamps.keys():
			for node in self.adversary_timestamps[key]:
				timestamp_dict[node] = key
		return timestamp_dict

	def draw_plot(self):
		values = ['r' for i in self.nodes()]
		values[-1] = 'b'

		pos=nx.circular_layout(self) # positions for all nodes

		nx.draw(self, pos = pos, node_color = values)

		labels={}
		for i in self.nodes():
			labels[i] = str(i)
		nx.draw_networkx_labels(self,pos,labels,font_size=16)

		plt.show()

	def spread_message(self):
		t = 1
		candidates = []
		reached_adversary = False

		
		while (t <= self.spreading_time):
			current_active = [item for item in self.active]
			for node in current_active:
				# Make sure this node has its full set of neighbors: tree_degree honest
				# neighbors, plus one extra slot if it is connected to the adversary
				if self.has_edge(node, self.adversary):
					expected_degree = self.tree_degree + 1
				else:
					expected_degree = self.tree_degree
				if self.degree(node) < expected_degree:
					num_missing = expected_degree - self.degree(node)
					new_nodes = range(self.max_node + 1, self.max_node + num_missing + 1)
					self.add_edges(node, new_nodes)
				
				# print 'adjacency: ', self.edges()

				# Spread to the active nodes' uninfected neighbors
				uninfected_neighbors = self.get_uninfected_neighbors(node)
				# print 'uninfected_neighbors', uninfected_neighbors, 'current_active', current_active, 'node', node
				to_infect = random.choice(uninfected_neighbors)
				self.infect_node(node, to_infect)
				# print 'node ', node, ' infected ', to_infect
				
				
				if to_infect == self.adversary:
					# Record which node reported the message to the adversary at time t
					if t in self.adversary_timestamps:
						self.adversary_timestamps[t].append(node)
					else:
						self.adversary_timestamps[t] = [node]
			t += 1