Example No. 1
from sortedcontainers import SortedDict


class PrioritizedIntensity(object):
    _MIN_VALUE = 0.005

    def __init__(self):
        self._values = SortedDict()

    def set(self, value, priority=100):
        value = float(value)
        if value < self._MIN_VALUE and priority in self._values:
            del self._values[priority]
        else:
            self._values[priority] = value

    def eval(self):
        if not self._values:
            return 0.0
        return self._values.peekitem(-1)[1]  # value at the highest priority (iloc was removed in sortedcontainers v2)

    def top_priority(self):
        if not self._values:
            return 0
        return self._values.peekitem(-1)[0]

    def reset(self):
        self._values.clear()
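
A minimal usage sketch for the class above (values here are illustrative):

intensity = PrioritizedIntensity()
intensity.set(0.5)                  # default priority 100
intensity.set(0.9, priority=200)    # higher priority takes precedence
print(intensity.eval())             # 0.9 - value at the highest priority
print(intensity.top_priority())     # 200
intensity.set(0.0, priority=200)    # below _MIN_VALUE: drops that priority
print(intensity.eval())             # 0.5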
Example No. 2
from sortedcontainers import SortedDict


class MemTable:
    """
    Internal ordered map (a SortedDict, which behaves like a balanced-BST map
    but is actually backed by sorted lists). It holds entries in sorted order
    and should be used in conjunction with jumpDB.DB.
    """
    def __init__(self, max_size):
        self._entries = SortedDict()
        self.max_size = max_size

    def __setitem__(self, key, value):
        self._entries[key] = value

    def __len__(self):
        return len(self._entries)

    def __getitem__(self, item):
        return self._entries[item]

    def clear(self):
        self._entries.clear()

    def __contains__(self, item):
        return item in self._entries

    def capacity_reached(self):
        return len(self._entries) >= self.max_size

    def __iter__(self):
        for key, value in self._entries.items():
            yield (key, value)
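
Usage sketch: entries iterate in key order regardless of insertion order.

table = MemTable(max_size=3)
table["b"] = 2
table["a"] = 1
print(list(table))               # [('a', 1), ('b', 2)]
table["c"] = 3
print(table.capacity_reached())  # True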
Example No. 3
import string

from sortedcontainers import SortedDict


def test_clear():
    mapping = [(val, pos) for pos, val in enumerate(string.ascii_lowercase)]
    temp = SortedDict(mapping)
    assert len(temp) == 26
    assert list(temp.items()) == mapping
    temp.clear()
    assert len(temp) == 0
Example No. 5
import operator

from sortedcontainers import SortedDict

# Side, OrderBasedLevel and Book are assumed to come from the surrounding
# market_data module.


class OrderBasedBook:
    def __init__(self):
        self._bids = SortedDict(operator.neg)
        self._asks = SortedDict()

    def add_order(self, side, order_id, price, qty):
        lvl = self._fetch_level(side, price)
        lvl.add_order(order_id, qty)

    def change_order(self, side, order_id, price, new_qty):
        """Change quantity for order.

        Returns True if order exists in book, False otherwise.
        """
        lvl = self._fetch_level(side, price)
        return lvl.change_order(order_id, new_qty)

    def match_order(self, side, order_id, price, trade_qty):
        lvl = self._fetch_level(side, price)
        lvl.match_order(order_id, trade_qty)
        if lvl.empty:
            self._remove_level(side, price)

    def remove_order(self, side, order_id, price):
        lvl = self._fetch_level(side, price)
        removed = lvl.remove_order(order_id)
        if lvl.empty:
            self._remove_level(side, price)
        return removed

    def clear(self):
        self._bids.clear()
        self._asks.clear()

    def make_book(self, sequence):
        """Return ``market_data.Book`` for OrderBasedBook."""
        return Book(
                sequence=sequence,
                bids=list(self._bids.values()),
                asks=list(self._asks.values()))

    def _fetch_level(self, side, price):
        levels = self._choose_side(side)
        return OrderBasedBook._get_level(price, levels)

    def _remove_level(self, side, price):
        levels = self._choose_side(side)
        levels.pop(price, None)

    def _choose_side(self, side):
        return self._bids if side == Side.BID else self._asks

    @staticmethod
    def _get_level(price, levels):
        if price not in levels:
            levels[price] = OrderBasedLevel(price=price)
        return levels[price]
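
The bid-side trick above is worth isolating: passing operator.neg as the key
function makes the SortedDict iterate bids from best (highest) price to
worst, while asks keep the natural ascending order. A standalone sketch:

import operator
from sortedcontainers import SortedDict

bids = SortedDict(operator.neg)
for px in (99, 101, 100):
    bids[px] = 'level@%d' % px
print(list(bids))   # [101, 100, 99] - best bid first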
Example No. 6
# Qt binding assumed to be PyQt5; adjust the import for PySide as needed.
from PyQt5.QtCore import QAbstractListModel, QModelIndex

from sortedcontainers import SortedDict


class QtDictListModel(QAbstractListModel):
    def __init__(self):
        QAbstractListModel.__init__(self)
        self._items = SortedDict()

    def role(self, item, role):
        # Default role mapping; subclasses override this to return
        # role-specific data.
        return item

    def rowCount(self, parent):
        if parent.isValid():
            return 0
        return len(self._items)

    def from_index(self, index):
        if not index.isValid() or index.row() >= len(self._items):
            return None
        return self._items.peekitem(index.row())[1]

    def data(self, index, role):
        item = self.from_index(index)
        if item is None:
            return None
        return self.role(item, role)

    def _add(self, key, item):
        assert key not in self._items
        next_index = self._items.bisect_left(key)
        self.beginInsertRows(QModelIndex(), next_index, next_index)
        self._items[key] = item
        self.endInsertRows()

    # TODO - removal is O(n).
    def _remove(self, key):
        assert key in self._items
        item_index = self._items.index(key)
        self.beginRemoveRows(QModelIndex(), item_index, item_index)
        del self._items[key]
        self.endRemoveRows()

    def _clear(self):
        self.beginRemoveRows(QModelIndex(), 0, len(self._items) - 1)
        self._items.clear()
        self.endRemoveRows()

    # O(n). Rework if it's too slow.
    def _update(self, key, roles=None):
        item_index = self._items.index(key)
        index = self.index(item_index, 0)
        if roles is None:
            self.dataChanged.emit(index, index)
        else:
            self.dataChanged.emit(index, index, roles)
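
The _add() method above uses SortedDict.bisect_left to learn which row a new
key will occupy before inserting it, so the model can announce the insertion
range to Qt first. The lookup in isolation:

from sortedcontainers import SortedDict

items = SortedDict({'a': 1, 'c': 3})
print(items.bisect_left('b'))   # 1 - row where 'b' would be inserted
items['b'] = 2
print(items.index('b'))         # 1 - and where it actually landed

Example No. 7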
from sortedcontainers import SortedDict


def test_init():
    sdict = SortedDict()
    sdict._check()

    # sortedcontainers v1 API: v2 dropped the load= argument (use _reset()).
    sdict = SortedDict(load=17)
    sdict._check()

    sdict = SortedDict((val, -val) for val in range(10000))
    sdict._check()
    assert all(key == -val for key, val in sdict.iteritems())  # v1 API; items() in v2

    sdict.clear()
    sdict._check()
    assert len(sdict) == 0

    sdict = SortedDict.fromkeys(range(1000), None)
    assert all(sdict[key] is None for key in range(1000))
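Example No. 9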
from sortedcontainers import SortedDict


class TransactionRepository:
    def __init__(self):
        self.__accounts = SortedDict()

    def add_amount(self, account, amount):
        account = int(account)
        amount = float(amount)
        self.__accounts[account] = self.__accounts.get(account, 0) + amount

    def get_account_amount(self, account):
        return self.__accounts[int(account)]

    def get_formatted_transactions(self):
        return self.__accounts.items()  # iteritems() existed only in sortedcontainers v1

    def clear(self):
        self.__accounts.clear()
Example No. 10
from typing import List

from sortedcontainers import SortedDict

# PriceUpdate is assumed to be defined in the surrounding module.


class SortedDictBook:
    def __init__(self, depth):
        self.depth = depth
        self.orders = SortedDict()
        self.empty_price = PriceUpdate()

    def update(self, price_update: PriceUpdate) -> bool:
        # Action codes are ASCII byte values: '0' new, '1' update, '2' delete,
        # '3' delete-thru, '4' delete-from.
        action = {
            48: self.new_order,     # ord('0')
            49: self.update_order,  # ord('1')
            50: self.delete_order,  # ord('2')
            51: self.delete_thru,   # ord('3')
            52: self.delete_from,   # ord('4')
        }.get(price_update.action, None)

        if action:
            action(price_update)
            return True

        return False

    def update_order(self, price_update: PriceUpdate):
        self.orders[price_update.price] = price_update

    def delete_order(self, price_update: PriceUpdate):
        self.orders.pop(price_update.price)

    def delete_thru(self, price_update: PriceUpdate):
        self.orders.clear()

    def new_order(self, price_update: PriceUpdate):
        if len(self.orders) == self.depth:
            # Book is full: popitem() drops the last (highest-price) entry.
            self.orders.popitem()

        self.orders[price_update.price] = price_update

    def get_book(self) -> List[PriceUpdate]:
        return list(self.orders.values())  # materialize the view to match the annotation

    def delete_from(self, price_update: PriceUpdate):
        # sortedcontainers v1 allowed "del self.orders.iloc[:n]"; v2 removed
        # iloc, so delete the first `level` keys explicitly.
        for key in list(self.orders.islice(stop=price_update.level)):
            del self.orders[key]

    def top(self) -> PriceUpdate:
        return self.orders.peekitem(0)[1] if self.orders else self.empty_price
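
Note how new_order above caps the book depth: SortedDict.popitem() defaults
to index=-1, so the entry with the highest price is evicted. In isolation:

from sortedcontainers import SortedDict

book = SortedDict({100: 'a', 101: 'b', 102: 'c'})
book.popitem()      # removes (102, 'c'), the last (highest) key
print(list(book))   # [100, 101]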
Example No. 11
import numpy as np

from sortedcontainers import SortedDict


class EventQueue:
    def __init__(self):
        self.sortedEvents = SortedDict()
        self.eventTimes = {}

    def pollEvent(self):
        return self.sortedEvents.popitem(index=0)

    def peekEvent(self):
        return self.sortedEvents.peekitem(index=0)

    def remove(self, event):
        time = self.eventTimes.get(event)
        if time is not None:
            if time in self.sortedEvents:
                self.sortedEvents.pop(time)
            if event in self.eventTimes:
                self.eventTimes.pop(event)

    def clear(self):
        self.sortedEvents.clear()
        self.eventTimes.clear()

    def add(self, event, time):
        if np.isinf(time):
            return None
        if self.containsTime(time):
            raise ValueError(
                "EventQueue does not support two events at the same time" +
                " " + str(time))

        if isinstance(time, np.ndarray):
            time = float(np.asarray(time, dtype=float)[0])  # np.float was removed from NumPy
        self.sortedEvents[time] = event
        self.eventTimes[event] = time

    def containsTime(self, time):
        # Membership test directly against the SortedDict keys; no need to
        # copy them into a NumPy array first.
        return time in self.sortedEvents

    def peekTime(self):
        return self.sortedEvents.keys()[0]
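
Usage sketch for the EventQueue above (numpy is imported with the class):

q = EventQueue()
q.add('b', 2.0)
q.add('a', 1.0)
print(q.peekTime())     # 1.0 - earliest time first
print(q.pollEvent())    # (1.0, 'a')
q.add('c', np.inf)      # silently ignored: infinite times are skipped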
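Example No. 13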
    def test_SortedDict(self):
        # construct
        sorted_dict = SortedDict({'a': 1, 'c': 2, 'b': 3})
        print('sorted dict is: ', sorted_dict)

        # adding key => value pairs
        sorted_dict['d'] = 3
        print('sorted dict after adding an element: ', sorted_dict)

        # adding element using setdefault()
        sorted_dict.setdefault('e', 4)
        print('sorted dict after setdefault(): ', sorted_dict)

        # using the get function
        print('using the get function to print the value of a: ', sorted_dict.get('a', 0))
        for key in sorted_dict:
            print('{} -> {}'.format(key, sorted_dict[key]), end=' ')
        print()

        # removing all elements from the dict
        sorted_dict.clear()
Example No. 14
from sortedcontainers import SortedDict


class PriorityQueue:
    def __init__(self):
        self.__buckets__ = SortedDict()
        return

    def enqueue(self, priority, value):
        bucket = self.__buckets__.setdefault(priority, [])
        bucket.append(value)
        return

    def dequeue(self):
        if len(self.__buckets__) > 0:
            (priority, bucket) = self.__buckets__.peekitem(0)
            value = bucket[0]
            if len(bucket) > 1:
                self.__buckets__[priority] = bucket[1:]
            else:
                del self.__buckets__[priority]
            return (priority, value)
        else:
            return None

    def clear(self):
        self.__buckets__.clear()
        return

    def size(self):
        total = 0
        for bucket in self.__buckets__.values():
            total += len(bucket)

        return total

    def count(self, priority):
        return len(
            self.__buckets__[priority]) if priority in self.__buckets__ else 0
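
Usage sketch: because SortedDict orders the buckets by key, the lowest
priority dequeues first, FIFO within a bucket.

pq = PriorityQueue()
pq.enqueue(2, 'low')
pq.enqueue(1, 'urgent')
pq.enqueue(1, 'also urgent')
print(pq.dequeue())   # (1, 'urgent')
print(pq.size())      # 2
print(pq.count(1))    # 1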
Example No. 15
# Widget from the angr-management GUI; it relies on Qt classes (QWidget,
# QSize, QPen, QBrush, QRectF, QPolygonF, QPointF, Qt, QHBoxLayout), on cle,
# and on angr-management helpers (Conf, ObjectContainer, Orientation,
# QFeatureMapView) imported by the surrounding module.
class QFeatureMap(QWidget):
    """
    Byte-level map of the memory space.
    """
    def __init__(self, disasm_view, parent=None):
        super().__init__(parent)

        self.disasm_view = disasm_view
        self.workspace = disasm_view.workspace
        self.instance = self.workspace.instance

        self.orientation = Orientation.Vertical

        # widgets
        self.view = None  # type: QFeatureMapView

        # items
        self._insn_indicators = []

        # data instance
        self.addr = ObjectContainer(
            None, name='The current address of the Feature Map.')

        # cached values
        self._addr_to_region = SortedDict()
        self._regionaddr_to_offset = SortedDict()
        self._offset_to_regionaddr = SortedDict()
        self._total_size = None
        self._regions_painted = False

        self._init_widgets()
        self._register_events()

    def sizeHint(self):
        return QSize(25, 25)

    #
    # Public methods
    #

    def refresh(self):

        if self.view is None:
            return

        if not self._regions_painted:
            self._regions_painted = True
            self._paint_regions()

    def select_offset(self, offset):

        addr = self._get_addr_from_pos(offset)
        if addr is None:
            return
        self.addr.am_obj = addr
        self.addr.am_event()

    #
    # Private methods
    #

    def _init_widgets(self):
        self.view = QFeatureMapView(self)

        layout = QHBoxLayout()
        layout.addWidget(self.view)

        layout.setContentsMargins(0, 0, 0, 0)

        self.setLayout(layout)

    def _register_events(self):
        self.disasm_view.infodock.selected_insns.am_subscribe(
            self._paint_insn_indicators)

    def _paint_regions(self):

        cfb = self.instance.cfb_container.am_obj

        if cfb is None:
            return

        # colors
        func_color = Conf.feature_map_color_regular_function
        data_color = Conf.feature_map_color_data
        unknown_color = Conf.feature_map_color_unknown
        delimiter_color = Conf.feature_map_color_delimiter
        if self._total_size is None:
            # calculate the total number of bytes
            b = 0
            self._addr_to_region.clear()
            self._regionaddr_to_offset.clear()
            for mr in cfb.regions:
                self._addr_to_region[mr.addr] = mr
                self._regionaddr_to_offset[mr.addr] = b
                self._offset_to_regionaddr[b] = mr.addr
                b += self._adjust_region_size(mr)
            self._total_size = b

        # iterate through all items and draw the image
        offset = 0
        total_width = self.width()
        current_region = None
        height = self.height()
        print(total_width)
        for addr, obj in cfb.ceiling_items():

            # are we in a new region?
            new_region = False
            if current_region is None or not (
                    current_region.addr <= addr <
                    current_region.addr + current_region.size):
                current_region_addr = next(
                    self._addr_to_region.irange(maximum=addr, reverse=True))
                current_region = self._addr_to_region[current_region_addr]
                new_region = True

            # adjust size
            adjusted_region_size = self._adjust_region_size(current_region)
            adjusted_size = min(
                obj.size, current_region.addr + adjusted_region_size - addr)
            if adjusted_size <= 0:
                continue

            pos = offset * total_width // self._total_size
            length = adjusted_size * total_width // self._total_size
            offset += adjusted_size

            # draw a rectangle
            if isinstance(obj, Unknown):
                pen = QPen(data_color)
                brush = QBrush(data_color)
            elif isinstance(obj, Block):
                # TODO: Check if it belongs to a function or not
                pen = QPen(func_color)
                brush = QBrush(func_color)
            else:
                pen = QPen(unknown_color)
                brush = QBrush(unknown_color)
            rect = QRectF(pos, 0, length, height)
            self.view._scene.addRect(rect, pen, brush)

            # if at the beginning of a new region, draw a line
            if new_region:
                pen = QPen(delimiter_color)
                self.view._scene.addLine(pos, 0, pos, height, pen)

    def _adjust_region_size(self, memory_region):

        if isinstance(memory_region.object,
                      (cle.ExternObject, cle.TLSObject, cle.KernelObject)):
            # Draw unnecessary objects smaller
            return 80
        else:
            print(memory_region.size, memory_region.object)
            return memory_region.size

    def _get_pos_from_addr(self, addr):

        # find the region it belongs to
        try:
            mr_base = next(
                self._addr_to_region.irange(maximum=addr, reverse=True))
        except StopIteration:
            return None

        # get the base offset of that region
        base_offset = self._regionaddr_to_offset[mr_base]

        offset = base_offset + addr - mr_base
        return offset * self.width() // self._total_size

    def _get_addr_from_pos(self, pos):

        offset = int(pos * self._total_size // self.width())

        try:
            base_offset = next(
                self._offset_to_regionaddr.irange(maximum=offset,
                                                  reverse=True))
        except StopIteration:
            return None

        region_addr = self._offset_to_regionaddr[base_offset]
        return region_addr + offset - base_offset

    def _paint_insn_indicators(self):

        scene = self.view.scene()  # type: QGraphicsScene
        for item in self._insn_indicators:
            scene.removeItem(item)
        self._insn_indicators.clear()

        for selected_insn_addr in self.disasm_view.infodock.selected_insns:
            pos = self._get_pos_from_addr(selected_insn_addr)
            if pos is None:
                continue

            pos -= 1  # this is the top-left x coordinate of our arrow body (the rectangle)

            pen = QPen(Qt.yellow)
            brush = QBrush(Qt.yellow)
            rect = QRectF(pos, 0, 2, 5)
            # rectangle
            item = scene.addRect(rect, pen, brush)
            self._insn_indicators.append(item)
            # triangle
            triangle = QPolygonF()
            triangle.append(QPointF(pos - 1, 5))
            triangle.append(QPointF(pos + 3, 5))
            triangle.append(QPointF(pos + 1, 7))
            triangle.append(QPointF(pos - 1, 5))
            item = scene.addPolygon(triangle, pen, brush)
            self._insn_indicators.append(item)
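
The irange(maximum=addr, reverse=True) pattern used twice above is a floor
lookup: it finds the greatest region start that is <= addr. A standalone
sketch with made-up region addresses:

from sortedcontainers import SortedDict

regions = SortedDict({0x1000: 'text', 0x4000: 'data', 0x8000: 'bss'})
addr = 0x5f00
base = next(regions.irange(maximum=addr, reverse=True))
print(hex(base), regions[base])   # 0x4000 data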
Example No. 16
# Excerpt from a Qt-based P2P sync client; it relies on Qt (QObject, Signal,
# QTimer), sortedcontainers.SortedDict, and helpers/constants from the
# surrounding package (DOWNLOAD_PART_SIZE, DOWNLOAD_CHUNK_SIZE, Rsync,
# remove_file, get_free_space_by_filepath, logger, ...).
class DownloadTask(QObject):
    download_ready = Signal(QObject)
    download_not_ready = Signal(QObject)
    download_complete = Signal(QObject)
    download_failed = Signal(QObject)
    download_error = Signal(str)
    download_ok = Signal()

    download_finishing = Signal()
    copy_added = Signal(str)
    chunk_downloaded = Signal(
        str,  # obj_id
        str,  # str(offset) to fix offset >= 2**31
        int)  # length
    chunk_aborted = Signal()
    request_data = Signal(
        str,  # node_id
        str,  # obj_id
        str,  # str(offset) to fix offset >= 2**31
        int)  # length
    abort_data = Signal(
        str,  # node_id
        str,  # obj_id
        str)  # str(offset) to fix offset >= 2**31
    possibly_sync_folder_is_removed = Signal()
    no_disk_space = Signal(
        QObject,  # task
        str,  # display_name
        bool)  # is error
    wrong_hash = Signal(QObject)  # task)
    signal_info_rx = Signal(tuple)

    default_part_size = DOWNLOAD_PART_SIZE
    receive_timeout = 20  # seconds
    retry_limit = 2
    timeouts_limit = 2
    max_node_chunk_requests = 128
    end_race_timeout = 5.  # seconds

    def __init__(self,
                 tracker,
                 connectivity_service,
                 priority,
                 obj_id,
                 obj_size,
                 file_path,
                 display_name,
                 file_hash=None,
                 parent=None,
                 files_info=None):
        QObject.__init__(self, parent=parent)
        self._tracker = tracker
        self._connectivity_service = connectivity_service

        self.priority = priority
        self.size = obj_size
        self.id = obj_id
        self.file_path = file_path
        self.file_hash = file_hash
        self.download_path = file_path + '.download'
        self._info_path = file_path + '.info'
        self.display_name = display_name
        self.received = 0
        self.files_info = files_info

        self.hash_is_wrong = False
        self._ready = False
        self._started = False
        self._paused = False
        self._finished = False
        self._no_disk_space_error = False

        self._wanted_chunks = SortedDict()
        self._downloaded_chunks = SortedDict()
        self._nodes_available_chunks = dict()
        self._nodes_requested_chunks = dict()
        self._nodes_last_receive_time = dict()
        self._nodes_downloaded_chunks_count = dict()
        self._nodes_timeouts_count = dict()
        self._total_chunks_count = 0

        self._file = None
        self._info_file = None

        self._started_time = time()

        self._took_from_turn = 0
        self._received_via_turn = 0
        self._received_via_p2p = 0

        self._retry = 0

        self._limiter = None

        self._init_wanted_chunks()

        self._on_downloaded_cb = None
        self._on_failed_cb = None
        self.download_complete.connect(self._on_downloaded)
        self.download_failed.connect(self._on_failed)

        self._timeout_timer = QTimer(self)
        self._timeout_timer.setInterval(15 * 1000)
        self._timeout_timer.setSingleShot(False)
        self._timeout_timer.timeout.connect(self._on_check_timeouts)

        self._leaky_timer = QTimer(self)
        self._leaky_timer.setInterval(1000)
        self._leaky_timer.setSingleShot(True)
        self._leaky_timer.timeout.connect(self._download_chunks)

        self._network_limited_error_set = False

    def __lt__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__lt__(self, other)

        if self == other:
            return False

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id < other.id

            return self.size - self.received < other.size - other.received

        return self.priority > other.priority

    def __le__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__le__(self, other)

        if self == other:
            return True

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id < other.id

            return self.size - self.received < other.size - other.received

        return self.priority >= other.priority

    def __gt__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__gt__(self, other)

        if self == other:
            return False

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id > other.id

            return self.size - self.received > other.size - other.received

        return self.priority <= other.priority

    def __ge__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__ge__(self, other)

        if self == other:
            return True

        if self.priority == other.priority:
            if self.size - self.received == other.size - other.received:
                return self.id > other.id

            return self.size - self.received > other.size - other.received

        return self.priority <= other.priority

    def __eq__(self, other):
        if not isinstance(other, DownloadTask):
            return object.__eq__(self, other)

        return self.id == other.id

    def on_availability_info_received(self, node_id, obj_id, info):
        if obj_id != self.id or self._finished or not info:
            return

        logger.info(
            "availability info received, "
            "node_id: %s, obj_id: %s, info: %s", node_id, obj_id, info)

        new_chunks_stored = self._store_availability_info(node_id, info)
        if not self._ready and new_chunks_stored:
            if self._check_can_receive(node_id):
                self._ready = True
                self.download_ready.emit(self)
            else:
                self.download_error.emit('Turn limit reached')

        if self._started and not self._paused \
                and not self._nodes_requested_chunks.get(node_id, None):
            logger.debug("Downloading next chunk")
            self._download_next_chunks(node_id)
            self._clean_nodes_last_receive_time()
            self._check_download_not_ready(self._nodes_requested_chunks)

    def on_availability_info_failure(self, node_id, obj_id, error):
        if obj_id != self.id or self._finished:
            return

        logger.info(
            "availability info failure, "
            "node_id: %s, obj_id: %s, error: %s", node_id, obj_id, error)
        try:
            if error["err_code"] == "FILE_CHANGED":
                self.download_failed.emit(self)
        except Exception as e:
            logger.warning("Can't parse error message. Reson: %s", e)

    def start(self, limiter):
        if exists(self.file_path):
            logger.info("download task file already downloaded %s",
                        self.file_path)
            self.received = self.size
            self.download_finishing.emit()
            self.download_complete.emit(self)
            return

        self._limiter = limiter

        if self._started:
            # if we swapped task earlier
            self.resume()
            return

        self._no_disk_space_error = False
        if not self.check_disk_space():
            return

        logger.info("starting download task, obj_id: %s", self.id)
        self._started = True
        self._paused = False
        self.hash_is_wrong = False
        self._started_time = time()
        self._send_start_statistic()
        if not self._open_file():
            return

        self._read_info_file()

        for downloaded_chunk in self._downloaded_chunks.items():
            self._remove_from_chunks(downloaded_chunk[0], downloaded_chunk[1],
                                     self._wanted_chunks)

        self.received = sum(self._downloaded_chunks.values())
        if self._complete_download():
            return

        self._download_chunks()
        if not self._timeout_timer.isActive():
            self._timeout_timer.start()

    def check_disk_space(self):
        if self.size * 2 + get_signature_file_size(self.size) > \
                get_free_space_by_filepath(self.file_path):
            self._emit_no_disk_space()
            return False

        return True

    def pause(self, disconnect_cb=True):
        self._paused = True
        if disconnect_cb:
            self.disconnect_callbacks()
        self.stop_download_chunks()

    def resume(self, start_download=True):
        self._started_time = time()
        self._paused = False
        self.hash_is_wrong = False
        if start_download:
            self._started = True
            self._download_chunks()
            if not self._timeout_timer.isActive():
                self._timeout_timer.start()

    def cancel(self):
        self._close_file()
        self._close_info_file()
        self.stop_download_chunks()

        self._finished = True

    def clean(self):
        logger.debug("Cleaning download files %s", self.download_path)
        try:
            remove_file(self.download_path)
        except:
            pass
        try:
            remove_file(self._info_path)
        except:
            pass

    def connect_callbacks(self, on_downloaded, on_failed):
        self._on_downloaded_cb = on_downloaded
        self._on_failed_cb = on_failed

    def disconnect_callbacks(self):
        self._on_downloaded_cb = None
        self._on_failed_cb = None

    @property
    def ready(self):
        return self._ready

    @property
    def paused(self):
        return self._paused

    @property
    def no_disk_space_error(self):
        return self._no_disk_space_error

    def _init_wanted_chunks(self):
        self._total_chunks_count = math.ceil(
            float(self.size) / float(DOWNLOAD_CHUNK_SIZE))

        self._wanted_chunks[0] = self.size

    def _on_downloaded(self, task):
        if callable(self._on_downloaded_cb):
            self._on_downloaded_cb(task)
            self._on_downloaded_cb = None

    def _on_failed(self, task):
        if callable(self._on_failed_cb):
            self._on_failed_cb(task)
            self._on_failed_cb = None

    def on_data_received(self, node_id, obj_id, offset, length, data):
        if obj_id != self.id or self._finished:
            return

        logger.debug(
            "on_data_received for objId: %s, offset: %s, from node_id: %s",
            self.id, offset, node_id)

        now = time()
        last_received_time = self._nodes_last_receive_time.get(node_id, 0.)
        if node_id in self._nodes_last_receive_time:
            self._nodes_last_receive_time[node_id] = now

        self._nodes_timeouts_count.pop(node_id, 0)

        downloaded_count = \
            self._nodes_downloaded_chunks_count.get(node_id, 0) + 1
        self._nodes_downloaded_chunks_count[node_id] = downloaded_count

        # to collect traffic info
        node_type = self._connectivity_service.get_self_node_type()
        is_share = node_type == "webshare"
        # tuple -> (obj_id, rx_wd, rx_wr, is_share)
        if self._connectivity_service.is_relayed(node_id):
            # relayed traffic
            info_rx = (obj_id, 0, length, is_share)
        else:
            # p2p traffic
            info_rx = (obj_id, length, 0, is_share)
        self.signal_info_rx.emit(info_rx)

        if not self._is_chunk_already_downloaded(offset):
            if not self._on_new_chunk_downloaded(node_id, offset, length,
                                                 data):
                return

        else:
            logger.debug("chunk %s already downloaded", offset)

        requested_chunks = self._nodes_requested_chunks.get(
            node_id, SortedDict())
        if not requested_chunks:
            return

        self._remove_from_chunks(offset, length, requested_chunks)

        if not requested_chunks:
            self._nodes_requested_chunks.pop(node_id, None)

        requested_count = sum(requested_chunks.values()) // DOWNLOAD_CHUNK_SIZE
        if downloaded_count * 4 >= requested_count \
                and requested_count < self.max_node_chunk_requests:
            self._download_next_chunks(node_id, now - last_received_time)
            self._clean_nodes_last_receive_time()
            self._check_download_not_ready(self._nodes_requested_chunks)

    def _is_chunk_already_downloaded(self, offset):
        if self._downloaded_chunks:
            chunk_index = self._downloaded_chunks.bisect_right(offset)
            if chunk_index > 0:
                chunk_index -= 1

                chunk = self._downloaded_chunks.peekitem(chunk_index)
                if offset < chunk[0] + chunk[1]:
                    return True

        return False

    def _on_new_chunk_downloaded(self, node_id, offset, length, data):
        if not self._write_to_file(offset, data):
            return False

        self.received += length
        if self._connectivity_service.is_relayed(node_id):
            self._received_via_turn += length
        else:
            self._received_via_p2p += length

        new_offset = offset
        new_length = length

        left_index = self._downloaded_chunks.bisect_right(new_offset)
        if left_index > 0:
            left_chunk = self._downloaded_chunks.peekitem(left_index - 1)
            if left_chunk[0] + left_chunk[1] == new_offset:
                new_offset = left_chunk[0]
                new_length += left_chunk[1]
                self._downloaded_chunks.popitem(left_index - 1)

        right_index = self._downloaded_chunks.bisect_right(new_offset +
                                                           new_length)
        if right_index > 0:
            right_chunk = self._downloaded_chunks.peekitem(right_index - 1)
            if right_chunk[0] == new_offset + new_length:
                new_length += right_chunk[1]
                self._downloaded_chunks.popitem(right_index - 1)

        self._downloaded_chunks[new_offset] = new_length

        assert self._remove_from_chunks(offset, length, self._wanted_chunks)

        logger.debug("new chunk downloaded from node: %s, wanted size: %s",
                     node_id, sum(self._wanted_chunks.values()))

        part_offset = (offset // DOWNLOAD_PART_SIZE) * DOWNLOAD_PART_SIZE
        part_size = min([DOWNLOAD_PART_SIZE, self.size - part_offset])
        if new_offset <= part_offset \
                and new_offset + new_length >= part_offset + part_size:
            if self._file:
                self._file.flush()
            self._write_info_file()

            self.chunk_downloaded.emit(self.id, str(part_offset), part_size)

        if self._complete_download():
            return False

        return True

    def _remove_from_chunks(self, offset, length, chunks):
        if not chunks:
            return False

        chunk_left_index = chunks.bisect_right(offset)
        if chunk_left_index > 0:
            left_chunk = chunks.peekitem(chunk_left_index - 1)
            if offset >= left_chunk[0] + left_chunk[1] \
                    and len(chunks) > chunk_left_index:
                left_chunk = chunks.peekitem(chunk_left_index)
            else:
                chunk_left_index -= 1
        else:
            left_chunk = chunks.peekitem(chunk_left_index)

        if offset >= left_chunk[0] + left_chunk[1] or \
                offset + length <= left_chunk[0]:
            return False

        chunk_right_index = chunks.bisect_right(offset + length)
        right_chunk = chunks.peekitem(chunk_right_index - 1)

        if chunk_right_index == chunk_left_index:
            to_del = [right_chunk[0]]
        else:
            to_del = list(chunks.islice(chunk_left_index, chunk_right_index))

        for chunk in to_del:
            chunks.pop(chunk)

        if left_chunk[0] < offset:
            if left_chunk[0] + left_chunk[1] >= offset:
                chunks[left_chunk[0]] = offset - left_chunk[0]

        if right_chunk[0] + right_chunk[1] > offset + length:
            chunks[offset + length] = \
                right_chunk[0] + right_chunk[1] - offset - length
        return True

    def on_data_failed(self, node_id, obj_id, offset, error):
        if obj_id != self.id or self._finished:
            return

        logger.info(
            "data request failure, "
            "node_id: %s, obj_id: %s, offset: %s, error: %s", node_id, obj_id,
            offset, error)

        self.on_node_disconnected(node_id)

    def get_downloaded_chunks(self):
        if not self._downloaded_chunks:
            return None

        return self._downloaded_chunks

    def on_node_disconnected(self,
                             node_id,
                             connection_alive=False,
                             timeout_limit_exceed=True):
        requested_chunks = self._nodes_requested_chunks.pop(node_id, None)
        logger.info("node disconnected %s, chunks removed from requested: %s",
                    node_id, requested_chunks)
        if timeout_limit_exceed:
            self._nodes_available_chunks.pop(node_id, None)
            self._nodes_timeouts_count.pop(node_id, None)
            if connection_alive:
                self._connectivity_service.reconnect(node_id)
        self._nodes_last_receive_time.pop(node_id, None)
        self._nodes_downloaded_chunks_count.pop(node_id, None)

        if connection_alive:
            self.abort_data.emit(node_id, self.id, None)

        if self._nodes_available_chunks:
            self._download_chunks(check_node_busy=True)
        else:
            chunks_to_test = self._nodes_requested_chunks \
                if self._started and not self._paused \
                else self._nodes_available_chunks
            self._check_download_not_ready(chunks_to_test)

    def complete(self):
        if self._started and not self._finished:
            self._complete_download(force_complete=True)
        elif not self._finished:
            self._finished = True
            self.clean()
            self.download_complete.emit(self)

    def _download_chunks(self, check_node_busy=False):
        if not self._started or self._paused or self._finished:
            return

        logger.debug("download_chunks for %s", self.id)

        node_ids = list(self._nodes_available_chunks.keys())
        random.shuffle(node_ids)
        for node_id in node_ids:
            node_free = not check_node_busy or \
                        not self._nodes_requested_chunks.get(node_id, None)
            if node_free:
                self._download_next_chunks(node_id)
        self._clean_nodes_last_receive_time()
        self._check_download_not_ready(self._nodes_requested_chunks)

    def _check_can_receive(self, node_id):
        return True

    def _write_to_file(self, offset, data):
        self._file.seek(offset)
        try:
            self._file.write(data)
        except EnvironmentError as e:
            logger.error("Download task %s can't write to file. Reason: %s",
                         self.id, e)
            self._send_error_statistic()
            if e.errno == errno.ENOSPC:
                self._emit_no_disk_space(error=True)
            else:
                self.download_failed.emit(self)
                self.possibly_sync_folder_is_removed.emit()
            return False

        return True

    def _open_file(self, clean=False):
        if not self._file or self._file.closed:
            try:
                if clean:
                    self._file = open(self.download_path, 'wb')
                else:
                    self._file = open(self.download_path, 'r+b')
            except IOError:
                try:
                    self._file = open(self.download_path, 'wb')
                except IOError as e:
                    logger.error(
                        "Can't open file for download for task %s. "
                        "Reason: %s", self.id, e)
                    self.download_failed.emit(self)
                    return False

        return True

    def _close_file(self):
        if not self._file:
            return True

        try:
            self._file.close()
        except EnvironmentError as e:
            logger.error("Download task %s can't close file. Reason: %s",
                         self.id, e)
            self._send_error_statistic()
            if e.errno == errno.ENOSPC:
                self._emit_no_disk_space(error=True)
            else:
                self.download_failed.emit(self)
                self.possibly_sync_folder_is_removed.emit()
            self._file = None
            return False

        self._file = None
        return True

    def _write_info_file(self):
        try:
            self._info_file.seek(0)
            self._info_file.truncate()
            pickle.dump(self._downloaded_chunks, self._info_file,
                        pickle.HIGHEST_PROTOCOL)
            self._info_file.flush()
        except EnvironmentError as e:
            logger.debug("Can't write to info file for task id %s. Reason: %s",
                         self.id, e)

    def _read_info_file(self):
        try:
            if not self._info_file or self._info_file.closed:
                self._info_file = open(self._info_path, 'a+b')
                self._info_file.seek(0)
            try:
                self._downloaded_chunks = pickle.load(self._info_file)
            except:
                pass
        except EnvironmentError as e:
            logger.debug("Can't open info file for task id %s. Reason: %s",
                         self.id, e)

    def _close_info_file(self, to_remove=False):
        if not self._info_file:
            return

        try:
            self._info_file.close()
            if to_remove:
                remove_file(self._info_path)
        except Exception as e:
            logger.debug(
                "Can't close or remove info file "
                "for task id %s. Reason: %s", self.id, e)
        self._info_file = None

    def _complete_download(self, force_complete=False):
        if (not self._wanted_chunks or force_complete) and \
                not self._finished:
            logger.debug("download %s completed", self.id)
            self._nodes_requested_chunks.clear()
            for node_id in self._nodes_last_receive_time.keys():
                self.abort_data.emit(node_id, self.id, None)

            if not force_complete:
                self.download_finishing.emit()

            if not force_complete and self.file_hash:
                hash_check_result = self._check_file_hash()
                if hash_check_result is not None:
                    return hash_check_result

            self._started = False
            self._finished = True
            self.stop_download_chunks()
            self._close_info_file(to_remove=True)
            if not self._close_file():
                return False

            try:
                if force_complete:
                    remove_file(self.download_path)
                    self.download_complete.emit(self)
                else:
                    shutil.move(self.download_path, self.file_path)
                    self._send_end_statistic()
                    self.download_complete.emit(self)
                    if self.file_hash:
                        self.copy_added.emit(self.file_hash)
            except EnvironmentError as e:
                logger.error(
                    "Download task %s can't (re)move file. "
                    "Reason: %s", self.id, e)
                self._send_error_statistic()
                self.download_failed.emit(self)
                self.possibly_sync_folder_is_removed.emit()
                return False

            result = True
        else:
            result = not self._wanted_chunks
        return result

    def _check_file_hash(self):
        self._file.flush()
        try:
            hash = Rsync.hash_from_block_checksum(
                Rsync.block_checksum(self.download_path))
        except IOError as e:
            logger.error("download %s error: %s", self.id, e)
            hash = None
        if hash != self.file_hash:
            logger.error(
                "download hash check failed objId: %s, "
                "expected hash: %s, actual hash: %s", self.id, self.file_hash,
                hash)
            if not self._close_file() or not self._open_file(clean=True):
                return False

            self._downloaded_chunks.clear()
            self._nodes_downloaded_chunks_count.clear()
            self._nodes_last_receive_time.clear()
            self._nodes_timeouts_count.clear()
            self._write_info_file()
            self._init_wanted_chunks()

            self.received = 0
            if self._retry < self.retry_limit:
                self._retry += 1
                self.resume()
            else:
                self._retry = 0
                self._nodes_available_chunks.clear()
                self.hash_is_wrong = True
                self.wrong_hash.emit(self)
            return True

        return None

    def _download_next_chunks(self, node_id, time_from_last_received_chunk=0.):
        if (self._paused or not self._started or not self._ready
                or self._finished or not self._wanted_chunks
                or self._leaky_timer.isActive()):
            return

        total_requested = sum(
            map(lambda x: sum(x.values()),
                self._nodes_requested_chunks.values()))

        if total_requested + self.received >= self.size:
            if self._nodes_requested_chunks.get(node_id, None) and \
                    time_from_last_received_chunk <= self.end_race_timeout:
                return

            available_chunks = \
                self._get_end_race_chunks_to_download_from_node(node_id)
        else:
            available_chunks = \
                self._get_available_chunks_to_download_from_node(node_id)

        if not available_chunks:
            logger.debug("no chunks available for download %s", self.id)
            logger.debug("downloading from: %s nodes, length: %s, wanted: %s",
                         len(self._nodes_requested_chunks), total_requested,
                         self.size - self.received)
            return

        available_offset = random.choice(list(available_chunks.keys()))
        available_length = available_chunks[available_offset]
        logger.debug("selected random offset: %s", available_offset)

        parts_count = math.ceil(
            float(available_length) / float(DOWNLOAD_PART_SIZE)) - 1
        logger.debug("parts count: %s", parts_count)

        part_to_download_number = random.randint(0, parts_count)
        offset = available_offset + \
                 part_to_download_number * DOWNLOAD_PART_SIZE
        length = min(DOWNLOAD_PART_SIZE,
                     available_offset + available_length - offset)
        logger.debug("selected random part: %s, offset: %s, length: %s",
                     part_to_download_number, offset, length)

        self._request_data(node_id, offset, length)

    def _get_end_race_chunks_to_download_from_node(self, node_id):
        available_chunks = self._nodes_available_chunks.get(node_id, None)
        if not available_chunks:
            return []

        available_chunks = available_chunks.copy()
        logger.debug("end race downloaded_chunks: %s", self._downloaded_chunks)
        logger.debug("end race requested_chunks: %s",
                     self._nodes_requested_chunks)
        logger.debug("end race available_chunks before excludes: %s",
                     available_chunks)
        if self._downloaded_chunks:
            for downloaded_chunk in self._downloaded_chunks.items():
                self._remove_from_chunks(downloaded_chunk[0],
                                         downloaded_chunk[1], available_chunks)
        if not available_chunks:
            return []

        available_from_other_nodes = available_chunks.copy()
        for requested_offset, requested_length in \
                self._nodes_requested_chunks.get(node_id, dict()).items():
            self._remove_from_chunks(requested_offset, requested_length,
                                     available_from_other_nodes)

        result = available_from_other_nodes if available_from_other_nodes \
            else available_chunks

        if result:
            logger.debug("end race available_chunks after excludes: %s",
                         available_chunks)
        return result

    def _get_available_chunks_to_download_from_node(self, node_id):
        available_chunks = self._nodes_available_chunks.get(node_id, None)
        if not available_chunks:
            return []

        available_chunks = available_chunks.copy()
        logger.debug("downloaded_chunks: %s", self._downloaded_chunks)
        logger.debug("requested_chunks: %s", self._nodes_requested_chunks)
        logger.debug("available_chunks before excludes: %s", available_chunks)
        for _, requested_chunks in self._nodes_requested_chunks.items():
            for requested_offset, requested_length in requested_chunks.items():
                self._remove_from_chunks(requested_offset, requested_length,
                                         available_chunks)
        if not available_chunks:
            return []

        for downloaded_chunk in self._downloaded_chunks.items():
            self._remove_from_chunks(downloaded_chunk[0], downloaded_chunk[1],
                                     available_chunks)
        logger.debug("available_chunks after excludes: %s", available_chunks)
        return available_chunks

    def _request_data(self, node_id, offset, length):
        logger.debug("Requesting date from node %s, request_chunk (%s, %s)",
                     node_id, offset, length)
        if self._limiter:
            try:
                self._limiter.leak(length)
            except LeakyBucketException:
                if node_id not in self._nodes_requested_chunks:
                    self._nodes_last_receive_time.pop(node_id, None)
                    if not self._network_limited_error_set:
                        self.download_error.emit('Network limited.')
                        self._network_limited_error_set = True
                if not self._leaky_timer.isActive():
                    self._leaky_timer.start()
                return

        if self._network_limited_error_set:
            self._network_limited_error_set = False
            self.download_ok.emit()

        requested_chunks = self._nodes_requested_chunks.get(node_id, None)
        if not requested_chunks:
            requested_chunks = SortedDict()
            self._nodes_requested_chunks[node_id] = requested_chunks
        requested_chunks[offset] = length
        logger.debug("Requested chunks %s", requested_chunks)
        self._nodes_last_receive_time[node_id] = time()
        self.request_data.emit(node_id, self.id, str(offset), length)

    def _clean_nodes_last_receive_time(self):
        for node_id in list(self._nodes_last_receive_time.keys()):
            if node_id not in self._nodes_requested_chunks:
                self._nodes_last_receive_time.pop(node_id, None)

    def _on_check_timeouts(self):
        if self._paused or not self._started \
                or self._finished or self._leaky_timer.isActive():
            return

        timed_out_nodes = set()
        cur_time = time()
        logger.debug("Chunk requests check %s",
                     len(self._nodes_requested_chunks))
        if self._check_download_not_ready(self._nodes_requested_chunks):
            return

        for node_id in self._nodes_last_receive_time:
            last_receive_time = self._nodes_last_receive_time.get(node_id)
            if cur_time - last_receive_time > self.receive_timeout:
                timed_out_nodes.add(node_id)

        logger.debug("Timed out nodes %s, nodes last receive time %s",
                     timed_out_nodes, self._nodes_last_receive_time)
        for node_id in timed_out_nodes:
            timeout_count = self._nodes_timeouts_count.pop(node_id, 0)
            timeout_count += 1
            if timeout_count >= self.timeouts_limit:
                retry = False
            else:
                retry = True
                self._nodes_timeouts_count[node_id] = timeout_count
            logger.debug("Node if %s, timeout_count %s, retry %s", node_id,
                         timeout_count, retry)
            self.on_node_disconnected(node_id,
                                      connection_alive=True,
                                      timeout_limit_exceed=not retry)

    def _get_chunks_from_info(self, chunks, info):
        new_added = False
        for part_info in info:
            logger.debug("get_chunks_from_info part_info %s", part_info)
            if part_info.length == 0:
                continue

            if not chunks:
                chunks[part_info.offset] = part_info.length
                new_added = True
                continue

            result_offset = part_info.offset
            result_length = part_info.length
            left_index = chunks.bisect_right(part_info.offset)
            if left_index > 0:
                left_chunk = chunks.peekitem(left_index - 1)
                if (left_chunk[0] <= part_info.offset
                        and left_chunk[0] + left_chunk[1] >=
                        part_info.offset + part_info.length):
                    continue

                if part_info.offset <= left_chunk[0] + left_chunk[1]:
                    result_offset = left_chunk[0]
                    result_length = part_info.offset + \
                                    part_info.length - result_offset
                    left_index -= 1

            right_index = chunks.bisect_right(part_info.offset +
                                              part_info.length)
            if right_index > 0:
                right_chunk = chunks.peekitem(right_index - 1)
                if part_info.offset + part_info.length <= \
                        right_chunk[0] + right_chunk[1]:
                    result_length = right_chunk[0] + \
                                    right_chunk[1] - result_offset

            to_delete = list(chunks.islice(left_index, right_index))

            for to_del in to_delete:
                chunks.pop(to_del)

            new_added = True
            chunks[result_offset] = result_length

        return new_added

    def _store_availability_info(self, node_id, info):
        known_chunks = self._nodes_available_chunks.get(node_id, None)
        if not known_chunks:
            known_chunks = SortedDict()
            self._nodes_available_chunks[node_id] = known_chunks
        return self._get_chunks_from_info(known_chunks, info)

    def _check_download_not_ready(self, checkable):
        if not self._wanted_chunks and self._started:
            self._complete_download(force_complete=False)
            return False

        if self._leaky_timer.isActive():
            if not self._nodes_available_chunks:
                self._make_not_ready()
                return True

        elif not checkable:
            self._make_not_ready()
            return True

        return False

    def _make_not_ready(self):
        if not self._ready:
            return

        logger.info("download %s not ready now", self.id)
        self._ready = False
        self._started = False
        if self._timeout_timer.isActive():
            self._timeout_timer.stop()
        if self._leaky_timer.isActive():
            self._leaky_timer.stop()
        self.download_not_ready.emit(self)

    def _clear_globals(self):
        self._wanted_chunks.clear()
        self._downloaded_chunks.clear()
        self._nodes_available_chunks.clear()
        self._nodes_requested_chunks.clear()
        self._nodes_last_receive_time.clear()
        self._nodes_downloaded_chunks_count.clear()
        self._nodes_timeouts_count.clear()
        self._total_chunks_count = 0

    def stop_download_chunks(self):
        if self._leaky_timer.isActive():
            self._leaky_timer.stop()
        if self._timeout_timer.isActive():
            self._timeout_timer.stop()

        for node_id in self._nodes_requested_chunks:
            self.abort_data.emit(node_id, self.id, None)

        self._nodes_requested_chunks.clear()
        self._nodes_last_receive_time.clear()

    def _emit_no_disk_space(self, error=False):
        self._no_disk_space_error = True
        self._nodes_available_chunks.clear()
        self._clear_globals()
        self._make_not_ready()
        file_name = self.display_name.split()[-1] \
            if self.display_name else ""
        self.no_disk_space.emit(self, file_name, error)

    def _send_start_statistic(self):
        if self._tracker:
            self._tracker.download_start(self.id, self.size)

    def _send_end_statistic(self):
        if self._tracker:
            time_diff = time() - self._started_time
            if time_diff < 1e-3:
                time_diff = 1e-3

            self._tracker.download_end(
                self.id,
                time_diff,
                websockets_bytes=0,
                webrtc_direct_bytes=self._received_via_p2p,
                webrtc_relay_bytes=self._received_via_turn,
                chunks=len(self._downloaded_chunks),
                chunks_reloaded=0,
                nodes=len(self._nodes_available_chunks))

    def _send_error_statistic(self):
        if self._tracker:
            time_diff = time() - self._started_time
            if time_diff < 1e-3:
                time_diff = 1e-3

            self._tracker.download_error(
                self.id,
                time_diff,
                websockets_bytes=0,
                webrtc_direct_bytes=self._received_via_p2p,
                webrtc_relay_bytes=self._received_via_turn,
                chunks=len(self._downloaded_chunks),
                chunks_reloaded=0,
                nodes=len(self._nodes_available_chunks))
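The tail of _get_chunks_from_info above coalesces an incoming (offset, length) pair with any overlapping entries before storing the merged range. A minimal standalone sketch of that idea, using a hypothetical merge_chunk helper that is not part of the original class:

from sortedcontainers import SortedDict

def merge_chunk(chunks, offset, length):
    """Insert (offset, length) into chunks (a SortedDict mapping
    offset -> length), coalescing overlapping or adjacent ranges.
    Returns True if the chunk added previously missing bytes."""
    start, end = offset, offset + length
    to_delete = []
    for off in chunks.irange(maximum=end):
        if off + chunks[off] < start:
            continue  # ends before the new range; no overlap
        # overlapping or adjacent: grow the merged range
        start = min(start, off)
        end = max(end, off + chunks[off])
        to_delete.append(off)
    if len(to_delete) == 1 and chunks[to_delete[0]] == end - start:
        return False  # fully contained in an existing chunk
    for off in to_delete:
        chunks.pop(off)
    chunks[start] = end - start
    return True

chunks = SortedDict()
merge_chunk(chunks, 0, 4)   # {0: 4}
merge_chunk(chunks, 2, 6)   # coalesces into {0: 8}
print(dict(chunks))         # {0: 8}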
Example #17
class ValueCounts(object):
    """A dictionary of value counts

    The dictionary of value counts comes out of pandas.series.value_counts()
    for one variable or pandas.Dataframe.groupby.size() performed over one
    or multiple variables.
    """

    def __init__(self, key, subkey=None, counts={}, sel={}):
        """Initialize ValueCounts instance

        :param list key: key is a tuple, list or string of (the) variable name(s), matching those and the structure of
               the keys in the value_counts dictionary.
        :param list subkey: subset of key. If provided, the value_counts dictionary will be projected from key onto the
               (subset of) subkey. E.g. use this to map a two-dimensional value_counts dictionary onto one specified
               dimension. Default is None. Optional.
        :param dict counts: the value_counts dictionary.
        :param dict sel: Apply selections to value_counts dictionary. Default is {}. Optional.
        """

        key = self._transform_key(key)
        subkey = self._transform_key(subkey) if subkey is not None else key
        counts = dict((k if isinstance(k, tuple) else (k,), v) for k, v in counts.items())

        self._key = key
        self._skey = subkey
        self._cnts = counts
        self._sel = dict((k, list(s) if hasattr(s, '__iter__') else [s]) for k, s in sel.items())
        self._ktos = tuple(key.index(k) for k in subkey)
        self._kind = dict((k, key.index(k)) for k in key)
        self._stok = tuple(subkey.index(k) if k in subkey else None for k in key)
        self._no_none_cnts = SortedDict()

    def __lt__(self, other):
        """Less than operator

        :param object other: the other ValueCounts object
        :returns: true or false
        :rtype: bool
        """

        return len(self._key) < len(other._key)

    def __gt__(self, other):
        """Greater than operator

        :param object other: the other ValueCounts object
        :returns: true or false
        :rtype: bool
        """

        return len(self._key) > len(other._key)

    def __eq__(self, other):
        """Equal to operator

        :param object other: the other ValueCounts object
        :returns: true or false
        :rtype: bool
        """

        return len(self._key) == len(other._key) and tuple(self._skey) == tuple(other._skey)

    def __le__(self, other):
        """Less than or equal to operator

        :param object other: the other ValueCounts object
        :returns: true or false
        :rtype: bool
        """

        return self.__lt__(other) or self.__eq__(other)

    def __ge__(self, other):
        """Greater than or equal to operator

        :param object other: the other ValueCounts object
        :returns: true or false
        :rtype: bool
        """

        return self.__gt__(other) or self.__eq__(other)

    def _transform_key(self, key):
        """Transform input key to desired tuple format

        Input key, a tuple, list, or string, gets transformed into the
        desired key format.  Desired key format is a tuple like this:

        * ('foo',) : for one variable (note the comma),
        * ('foo', 'bar') : for multiple variables.

        This format follows the same structure of the keys used in the
        internal value_counts dictionary.

        :param tuple key: input key, is a tuple, list, or string
        :returns: the tuplelized key
        :rtype: tuple
        """

        assert key, 'Input key contains no variable name(s). Expect str or tuple or list of strings.'

        has_itr = isinstance(key, list) or isinstance(key, tuple)
        if has_itr:
            if len(key) == 1:
                key = (key[0],)
            else:
                key = tuple(key)
        elif isinstance(key, str):
            key = (key,)
        else:
            # don't recognize the key! pass.
            pass
        return key

    @property
    def counts(self):
        """Value-counts dictionary

        :returns: after processing, returns the value_counts dictionary
        :rtype: dict
        """

        self.process_counts()
        return self._cnts

    @property
    def nononecounts(self):
        """Value-counts dictionary without None keys

        :returns: after processing, returns the value_counts dictionary without None keys
        :rtype: dict
        """

        self.process_counts()
        if len(self._no_none_cnts) == 0:
            self._no_none_cnts = SortedDict([(var, cnt) for var, cnt in self._cnts.items() if None not in var])
        # return dict([(var, cnt) for var, cnt in self._cnts.items() if None
        # not in var])
        return self._no_none_cnts

    @property
    def key(self):
        """Current value-counts key

        :returns: the key
        :rtype: tuple
        """

        self.process_counts()
        return self._key

    @property
    def skey(self):
        """Current value-counts subkey

        :returns: the subkey
        :rtype: tuple
        """

        return self._skey

    @property
    def num_bins(self):
        """Number of value-counts bins

        :returns: number of bins
        :rtype: int
        """

        return len(self.counts)

    @property
    def num_nonone_bins(self):
        """Number of not-none value-counts bins

        :returns: number of not-none bins
        :rtype: int
        """

        return len(self.nononecounts)

    @property
    def sum_counts(self):
        """Sum of counts of all value-counts bins

        :returns: the sum of counts of all bins
        :rtype: float
        """

        return sum(self._cnts.values())

    @property
    def sum_nonone_counts(self):
        """Sum of not-none counts of all value-counts bins

        :returns: the sum of not-none counts of all bins
        :rtype: float
        """

        return sum(self.nononecounts.values())

    def create_sub_counts(self, subkey, sel={}):
        """Project existing value counts onto a subset of keys

        E.g. map variables x,y onto single dimension x, so for each bin in x integrate over y.

        :param tuple subkey: input sub-key, is a tuple, list, or string.
                             This is the new key of variables for the returned ValueCounts object.
        :param dict sel: dictionary with selection. Default is {}.
        :returns: value_counts object where subkey has become the new key.
        :rtype: ValueCounts
        """

        subkey = self._transform_key(subkey)
        return ValueCounts(self.key, subkey, self.counts, sel)

    def count(self, value_bin):
        """Get bin count for specific bin-key value bin

        :param tuple value_bin: a specific key, and can be a list or tuple.
        :returns: specific bin counter value
        :rtype: int
        """

        self.process_counts()
        return self._cnts.get(tuple(value_bin[k] for k in self._stok), 0)

    def get_values(self, val_keys=()):
        """Get all key-values of a subset of keys

        E.g. give all x values of the keys, when the value_counts object has keys (x, y).

        :param tuple val_keys: a specific sub-key to get key values for.
        :returns: all key-values of a subset of keys.
        :rtype: tuple
        """

        self.process_counts()
        if not val_keys:
            val_keys = self._skey
        return sorted(set(tuple(key[self._kind[k]] for k in val_keys) for key in self._cnts.keys()))

    def remove_keys_of_inconsistent_type(self, prefered_key_type=None):
        """Remove keys with inconsistent data type(s)

        :param tuple prefered_key_type: the preferred key type to keep. Can be a
                                        tuple, list, or single type.  E.g. str
                                        or (int, str, float).  If None provided,
                                        the most common key type found is kept.
        """

        self.process_counts()

        # NB: np.dtype(type(k)).type gives back a consistent numpy type for
        # all strings, floats, ints, etc.

        # convert prefered_key_type to right format
        if prefered_key_type is not None:
            #has_itr = hasattr(prefered_key_type, '__iter__')
            has_itr = isinstance(prefered_key_type, list) or isinstance(prefered_key_type, tuple)

            if has_itr:
                if len(prefered_key_type) == 1:
                    prefered_key_type = (prefered_key_type[0],)
                else:
                    prefered_key_type = tuple(prefered_key_type)
            else:
                prefered_key_type = (prefered_key_type,)
            # turn into consistent types, used for comparison below
            prefered_key_type = tuple(np.dtype(k).type for k in prefered_key_type)

        # sort all keys by their key type, and count how often these types
        # occur
        cnts_types = Counter()
        cnts_keys = dict((tuple(np.dtype(type(k)).type for k in key), [])
                         for key in self._cnts)
        for key in self._cnts:
            ktype = tuple(np.dtype(type(k)).type for k in key)
            cnts_types[ktype] += self._cnts[key]
            cnts_keys[ktype].append(key)

        # pick the preferred key type to keep
        if prefered_key_type is None:
            # select most common key type
            prefered_key_type = cnts_types.most_common()[0][0]

        # remove all keys of different key type than preferred
        for ktype in cnts_types:
            if ktype == prefered_key_type:
                continue
            keys = cnts_keys[ktype]
            for k in keys:
                del self._cnts[k]

        # no_none_cnts gets refilled next time when called
        self._no_none_cnts.clear()

    def process_counts(self, accept_equiv=True):
        """Project value counts onto the existing subset of keys

        E.g. map variables x,y onto single dimension x, so for each bin in x integrate over y.

        :param bool accept_equiv: accept equivalence of key and subkey if
                                  subkey is in a different order than key.
                                  Default is true.
        :returns: successful projection or not
        :rtype: bool
        """

        # only process if counts need processing
        if not self._sel and self._key == self._skey:
            return False
        if not self._sel and accept_equiv and all(k in self._skey for k in self._key):
            return False

        # create new counts dictionary with subcounts
        scnts = {}
        for vals, cnt in self._cnts.items():
            # apply selection
            if self._sel and any(vals[self._kind[k]] not in s for k, s in self._sel.items()):
                continue

            # add counts for subkey to sum
            vkey = tuple(vals[i] for i in self._ktos)
            if vkey not in scnts:
                scnts[vkey] = 0
            scnts[vkey] += cnt

        # set subcounts as new counts
        self._key = self._skey
        self._cnts = scnts
        self._sel = {}
        self._kind = dict((k, self._key.index(k)) for k in self._key)
        self._ktos = self._stok = tuple(range(len(self._skey)))
        # no_none_cnts refilled when called
        self._no_none_cnts.clear()
        return True
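A hedged usage sketch for ValueCounts, assuming a counts dict shaped like the output of DataFrame.groupby(['x', 'y']).size().to_dict() (the data and variable names here are made up):

counts = {('a', 1): 10, ('a', 2): 5, ('b', 1): 3}
vc = ValueCounts(key=('x', 'y'), counts=counts)
print(vc.sum_counts)     # 18

# project onto 'x' only: for each x bin, integrate over y
vc_x = vc.create_sub_counts('x')
print(vc_x.counts)       # {('a',): 15, ('b',): 3}
print(vc_x.num_bins)     # 2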
Example #18
class IntervalDict(MutableMapping):
    """
    An IntervalDict is a dict-like data structure that maps from intervals to data,
    where keys can be single values or Interval instances.

    When keys are Interval instances, its behaviour merely corresponds to
    range queries and it returns IntervalDict instances corresponding to the
    subset of values covered by the given interval. If no matching value is
    found, an empty IntervalDict is returned.
    When keys are "single values", its behaviour corresponds to that of the
    Python built-in dict. When no matching value is found, a KeyError is raised.

    Note that this class does not aim to have the best performance, but is
    provided mainly for convenience. Its performance mainly depends on the
    number of distinct values (not keys) that are stored.
    """

    __slots__ = ("_storage", )

    def __init__(self, mapping_or_iterable=None):
        """
        Return a new IntervalDict.

        If no argument is given, an empty IntervalDict is created. If an argument
        is given, and is a mapping object (e.g., another IntervalDict), a
        new IntervalDict with the same key-value pairs is created. If an
        iterable is provided, it has to be a list of (key, value) pairs.

        :param mapping_or_iterable: optional mapping or iterable.
        """
        self._storage = SortedDict(_sort)  # Mapping from intervals to values

        if mapping_or_iterable is not None:
            self.update(mapping_or_iterable)

    @classmethod
    def _from_items(cls, items):
        """
        Fast creation of an IntervalDict with the provided items.

        The items have to satisfy the two following properties: (1) all keys
        must be disjoint intervals and (2) all values must be distinct.

        :param items: list of (key, value) pairs.
        :return: an IntervalDict
        """
        d = cls()
        for key, value in items:
            d._storage[key] = value

        return d

    def clear(self):
        """
        Remove all items from the IntervalDict.
        """
        self._storage.clear()

    def copy(self):
        """
        Return a shallow copy.

        :return: a shallow copy.
        """
        return IntervalDict._from_items(self.items())

    def get(self, key, default=None):
        """
        Return the values associated to given key.

        If the key is a single value, it returns a single value (if it exists) or
        the default value. If the key is an Interval, it returns a new IntervalDict
        restricted to given interval. In that case, the default value is used to
        "fill the gaps" (if any) w.r.t. given key.

        :param key: a single value or an Interval instance.
        :param default: default value (default to None).
        :return: an IntervalDict, or a single value if key is not an Interval.
        """
        if isinstance(key, Interval):
            d = self[key]
            d[key - d.domain()] = default
            return d
        else:
            try:
                return self[key]
            except KeyError:
                return default

    def find(self, value):
        """
        Return a (possibly empty) Interval i such that self[i] = value, and
        self[~i] != value.

        :param value: value to look for.
        :return: an Interval instance.
        """
        return Interval(*(i for i, v in self._storage.items() if v == value))

    def items(self):
        """
        Return a view object on the contained items sorted by key
        (see https://docs.python.org/3/library/stdtypes.html#dict-views).

        :return: a view object.
        """
        return self._storage.items()

    def keys(self):
        """
        Return a view object on the contained keys (sorted)
        (see https://docs.python.org/3/library/stdtypes.html#dict-views).

        :return: a view object.
        """
        return self._storage.keys()

    def values(self):
        """
        Return a view object on the contained values sorted by key
        (see https://docs.python.org/3/library/stdtypes.html#dict-views).

        :return: a view object.
        """
        return self._storage.values()

    def domain(self):
        """
        Return an Interval corresponding to the domain of this IntervalDict.

        :return: an Interval.
        """
        return Interval(*self._storage.keys())

    def pop(self, key, default=None):
        """
        Remove key and return the corresponding value if key is not an Interval.
        If key is an interval, it returns an IntervalDict instance.

        This method combines self[key] and del self[key]. If a default value
        is provided and is not None, it uses self.get(key, default) instead of
        self[key].

        :param key: a single value or an Interval instance.
        :param default: optional default value.
        :return: an IntervalDict, or a single value if key is not an Interval.
        """
        if default is None:
            value = self[key]
            del self[key]
            return value
        else:
            value = self.get(key, default)
            try:
                del self[key]
            except KeyError:
                pass
            return value

    def popitem(self):
        """
        Remove and return some (key, value) pair as a 2-tuple.
        Raise KeyError if D is empty.

        :return: a (key, value) pair.
        """
        return self._storage.popitem()

    def setdefault(self, key, default=None):
        """
        Return the value for given key. If the key does not exist, set its
        value to default and return it.

        :param key: a single value or an Interval instance.
        :param default: default value (default to None).
        :return: an IntervalDict, or a single value if key is not an Interval.
        """
        if isinstance(key, Interval):
            value = self.get(key, default)
            self.update(value)
            return value
        else:
            try:
                return self[key]
            except KeyError:
                self[key] = default
                return default

    def update(self, mapping_or_iterable):
        """
        Update current IntervalDict with provided values.

        If a mapping is provided, it must map Interval instances to values (e.g.,
        another IntervalDict). If an iterable is provided, it must consist of a
        list of (key, value) pairs.

        :param mapping_or_iterable: mapping or iterable.
        """
        if isinstance(mapping_or_iterable, Mapping):
            data = mapping_or_iterable.items()
        else:
            data = mapping_or_iterable

        for i, v in data:
            self[i] = v

    def combine(self, other, how):
        """
        Return a new IntervalDict that combines the values from current and
        provided ones.

        If d = d1.combine(d2, f), then d contains (1) all values from d1 whose
        keys do not intersect the ones of d2, (2) all values from d2 whose keys
        do not intersect the ones of d1, and (3) f(x, y) for x in d1, y in d2 for
        intersecting keys.

        :param other: another IntervalDict instance.
        :param how: a function of two parameters that combines values.
        :return: a new IntervalDict instance.
        """
        new_items = []

        dom1, dom2 = self.domain(), other.domain()

        new_items.extend(self[dom1 - dom2].items())
        new_items.extend(other[dom2 - dom1].items())

        intersection = dom1 & dom2
        d1, d2 = self[intersection], other[intersection]

        for i1, v1 in d1.items():
            for i2, v2 in d2.items():
                if i1.overlaps(i2):
                    i = i1 & i2
                    v = how(v1, v2)
                    new_items.append((i, v))

        return IntervalDict(new_items)

    def as_dict(self, atomic=False):
        """
        Return the content as a classical Python dict.

        :param atomic: whether keys are atomic intervals.
        :return: a Python dict.
        """
        if atomic:
            d = dict()
            for interval, v in self._storage.items():
                for i in interval:
                    d[i] = v
            return d
        else:
            return dict(self._storage)

    def __getitem__(self, key):
        if isinstance(key, Interval):
            items = []
            for i, v in self._storage.items():
                # Early out
                if key.upper < i.lower:
                    break

                intersection = key & i
                if not intersection.empty:
                    items.append((intersection, v))
            return IntervalDict._from_items(items)
        else:
            for i, v in self._storage.items():
                if key in i:
                    return v
            raise KeyError(key)

    def __setitem__(self, key, value):
        interval = key if isinstance(key, Interval) else singleton(key)

        if interval.empty:
            return

        removed_keys = []
        added_items = []

        found = False
        for i, v in self._storage.items():
            if value == v:
                found = True
                # Extend existing key
                removed_keys.append(i)
                added_items.append((i | interval, v))
            elif i.overlaps(interval):
                # Reduce existing key
                remaining = i - interval
                removed_keys.append(i)
                if not remaining.empty:
                    added_items.append((remaining, v))

        if not found:
            added_items.append((interval, value))

        # Update storage accordingly
        for key in removed_keys:
            self._storage.pop(key)

        for key, value in added_items:
            self._storage[key] = value

    def __delitem__(self, key):
        interval = key if isinstance(key, Interval) else singleton(key)

        if interval.empty:
            return

        removed_keys = []
        added_items = []

        found = False
        for i, v in self._storage.items():
            # Early out
            if interval.upper < i.lower:
                break

            if i.overlaps(interval):
                found = True
                remaining = i - interval
                removed_keys.append(i)
                if not remaining.empty:
                    added_items.append((remaining, v))

        if not found and not isinstance(key, Interval):
            raise KeyError(key)

        # Update storage accordingly
        for key in removed_keys:
            self._storage.pop(key)

        for key, value in added_items:
            self._storage[key] = value

    def __or__(self, other):
        d = self.copy()
        d.update(other)
        return d

    def __ior__(self, other):
        self.update(other)
        return self

    def __iter__(self):
        return iter(self._storage)

    def __len__(self):
        return len(self._storage)

    def __contains__(self, key):
        return key in self.domain()

    def __repr__(self):
        return "{}{}{}".format(
            "{",
            ", ".join("{!r}: {!r}".format(i, v) for i, v in self.items()),
            "}",
        )

    def __eq__(self, other):
        if isinstance(other, IntervalDict):
            return self.as_dict() == other.as_dict()
        else:
            return NotImplemented
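A minimal usage sketch. It assumes the Interval constructors from the same module (e.g. a closed() factory, as in the portion library this class closely resembles); the printed outputs are indicative:

d = IntervalDict()
d[closed(0, 3)] = 'banana'    # an Interval key covers a whole range
d[4] = 'apple'                # a single value becomes a singleton interval
print(d[2])                   # 'banana'
print(d.get(5, default='?'))  # '?'

d[closed(2, 4)] = 'orange'    # overwrites the overlapping parts
print(d.find('banana'))       # the interval still mapping to 'banana': [0,2)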
Example #19
class QLinearViewer(QWidget):
    def __init__(self, workspace, disasm_view, parent=None):
        super(QLinearViewer, self).__init__(parent)

        self.workspace = workspace
        self.disasm_view = disasm_view

        self.objects = []  # Objects that will be painted

        self.cfg = None
        self.cfb = None

        self._offset_to_region = SortedDict()
        self._addr_to_region_offset = SortedDict()

        # Offset (in bytes) into the entire blanket view
        self._offset = 0
        # The maximum offset (in bytes) of the blanket view
        self._max_offset = None
        # The first line that is rendered of the first object in self.objects. Start from 0.
        self._start_line_in_object = 0

        self._linear_view = None  # type: QLinearGraphicsView
        self._disasms = {}

        self._init_widgets()

    #
    # Properties
    #

    @property
    def offset(self):
        return self._offset

    @offset.setter
    def offset(self, v):
        self._offset = v

    @property
    def start_line_in_object(self):
        return self._start_line_in_object

    @property
    def max_offset(self):
        if self._max_offset is None:
            self._max_offset = self._calculate_max_offset()
        return self._max_offset

    #
    # Proxy properties
    #

    @property
    def selected_operands(self):
        return self._linear_view.selected_operands

    @property
    def selected_insns(self):
        return self._linear_view.selected_insns

    #
    # Public methods
    #

    def initialize(self):

        if self.cfb is None:
            return

        self._addr_to_region_offset.clear()
        self._offset_to_region.clear()
        self._disasms.clear()
        self._offset = 0
        self._max_offset = None
        self._start_line_in_object = 0

        # enumerate memory regions
        byte_offset = 0
        for mr in self.cfb.regions:  # type:MemoryRegion
            self._addr_to_region_offset[mr.addr] = byte_offset
            self._offset_to_region[byte_offset] = mr
            byte_offset += mr.size

    def navigate_to_addr(self, addr):
        if not self._addr_to_region_offset:
            return
        try:
            floor_region_addr = next(
                self._addr_to_region_offset.irange(maximum=addr, reverse=True))
        except StopIteration:
            floor_region_addr = next(self._addr_to_region_offset.irange())
        floor_region_offset = self._addr_to_region_offset[floor_region_addr]

        offset_into_region = addr - floor_region_addr
        self.navigate_to(floor_region_offset + offset_into_region)

    def refresh(self):
        self._linear_view.refresh()

    def navigate_to(self, offset):

        self._linear_view.navigate_to(int(offset))

        self.prepare_objects(offset)

        self._linear_view.refresh()

    def prepare_objects(self, offset, start_line=0):
        """
        Prepare objects to print based on offset and start_line. Update self.objects, self._offset, and
        self._start_line_in_object.

        :param int offset:      Beginning offset (in bytes) to display in the linear viewer.
        :param int start_line:  The first line into the first object to display in the linear viewer.
        :return:                None
        """

        if offset == self._offset and start_line == self._start_line_in_object:
            return

        # Convert the offset to memory region
        base_offset, mr = self._region_from_offset(
            offset)  # type: int,MemoryRegion
        if mr is None:
            return

        addr = self._addr_from_offset(mr, base_offset, offset)
        _l.debug("Address %#x, offset %d, start_line %d.", addr, offset,
                 start_line)

        if start_line < 0:
            # Which object are we currently displaying at the top of the disassembly view?
            try:
                top_obj_addr = self.cfb.floor_addr(addr=addr)
            except KeyError:
                top_obj_addr = addr

            # Reverse-iterate until we have enough lines to compensate start_line
            for obj_addr, obj in self.cfb.ceiling_items(addr=top_obj_addr,
                                                        reverse=True,
                                                        include_first=False):
                qobject = self._obj_to_paintable(obj_addr, obj)
                if qobject is None:
                    continue
                object_lines = int(qobject.height //
                                   self._linear_view.line_height())
                _l.debug(
                    "Compensating negative start_line: object %s, object_lines %d.",
                    obj, object_lines)
                start_line += object_lines
                if start_line >= 0:
                    addr = obj_addr
                    # Update offset
                    new_region_addr = next(
                        self._addr_to_region_offset.irange(maximum=addr,
                                                           reverse=True))
                    new_region_offset = self._addr_to_region_offset[
                        new_region_addr]
                    offset = (addr - new_region_addr) + new_region_offset
                    break
            else:
                # umm we don't have enough objects to compensate the negative start_line
                start_line = 0
                # update addr and offset to their minimal values
                addr = next(self._addr_to_region_offset.irange())
                offset = self._addr_to_region_offset[addr]

        _l.debug("After adjustment: Address %#x, offset %d, start_line %d.",
                 addr, offset, start_line)

        self.objects = []

        viewable_lines = int(self._linear_view.height() //
                             self._linear_view.line_height())
        lines = 0
        start_line_in_object = 0

        # Load a page of objects
        for obj_addr, obj in self.cfb.floor_items(addr=addr):
            qobject = self._obj_to_paintable(obj_addr, obj)
            if qobject is None:
                # Conversion failed
                continue

            if isinstance(qobject, QBlock):
                for insn_addr in qobject.addr_to_insns.keys():
                    self._linear_view._add_insn_addr_block_mapping(
                        insn_addr, qobject)

            object_lines = int(qobject.height //
                               self._linear_view.line_height())

            if start_line >= object_lines:
                # this object should be skipped. ignore it
                start_line -= object_lines
                # adjust the offset as well
                if obj_addr <= addr < obj_addr + obj.size:
                    offset += obj_addr + obj.size - addr
                else:
                    offset += obj.size
                _l.debug("Skipping object %s (size %d). New offset: %d.", obj,
                         obj.size, offset)
            else:
                if start_line > 0:
                    _l.debug(
                        "First object to paint: %s (size %d). Current offset %d.",
                        obj, obj.size, offset)
                    # this is the first object to paint
                    start_line_in_object = start_line
                    start_line = 0
                    lines += object_lines - start_line_in_object
                else:
                    lines += object_lines
                self.objects.append(qobject)

            if lines > viewable_lines:
                break

        _l.debug("Final offset %d, start_line_in_object %d.", offset,
                 start_line_in_object)

        # Update properties
        self._offset = offset
        self._start_line_in_object = start_line_in_object

    #
    # Private methods
    #

    def _init_widgets(self):

        self._linear_view = QLinearGraphicsView(self, self.disasm_view)

        layout = QHBoxLayout()
        layout.addWidget(self._linear_view)
        layout.setContentsMargins(0, 0, 0, 0)

        self.setLayout(layout)

        # Setup proxy methods
        self.update_label = self._linear_view.update_label
        self.select_instruction = self._linear_view.select_instruction
        self.unselect_instruction = self._linear_view.unselect_instruction
        self.unselect_all_instructions = self._linear_view.unselect_all_instructions
        self.select_operand = self._linear_view.select_operand
        self.unselect_operand = self._linear_view.unselect_operand
        self.unselect_all_operands = self._linear_view.unselect_all_operands
        self.show_selected = self._linear_view.show_selected
        self.show_instruction = self._linear_view.show_instruction

    def _obj_to_paintable(self, obj_addr, obj):
        if isinstance(obj, Block):
            cfg_node = self.cfg.get_any_node(obj.addr, force_fastpath=True)
            if cfg_node is not None:
                func_addr = cfg_node.function_address
                func = self.cfg.kb.functions[func_addr]  # FIXME: Resiliency
                disasm = self._get_disasm(func)
                qobject = QBlock(
                    self.workspace,
                    func_addr,
                    self.disasm_view,
                    disasm,
                    self.disasm_view.infodock,
                    obj.addr,
                    [obj],
                    {},
                    mode='linear',
                )
            else:
                # TODO: This should be displayed as a function thunk
                _l.error(
                    "QLinearViewer: Unexpected result: CFGNode %#x is not found in CFG."
                    "Display it as a QUnknownBlock.", obj.addr)
                qobject = QUnknownBlock(self.workspace, obj_addr, obj.bytes)

        elif isinstance(obj, Unknown):
            qobject = QUnknownBlock(self.workspace, obj_addr, obj.bytes)

        else:
            qobject = None

        return qobject

    def _calculate_max_offset(self):
        try:
            max_off = next(self._offset_to_region.irange(reverse=True))
            mr = self._offset_to_region[max_off]  # type: MemoryRegion
            return max_off + mr.size
        except StopIteration:
            return 0

    def _region_from_offset(self, offset):
        try:
            off = next(
                self._offset_to_region.irange(maximum=offset, reverse=True))
            return off, self._offset_to_region[off]
        except StopIteration:
            return None, None

    def _addr_from_offset(self, mr, base_offset, offset):
        return mr.addr + (offset - base_offset)

    def _get_disasm(self, func):
        """Return the disassembly analysis for the given function, caching it.

        :param func: the function to disassemble.
        :return: a Disassembly analysis result.
        """

        if func.addr not in self._disasms:
            self._disasms[func.addr] = \
                self.workspace.instance.project.analyses.Disassembly(function=func)
        return self._disasms[func.addr]
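navigate_to_addr above is a floor lookup on a SortedDict: find the last region starting at or below an address. A standalone sketch of the same pattern, with made-up addresses:

from sortedcontainers import SortedDict

addr_to_offset = SortedDict({0x400000: 0, 0x401000: 0x1000, 0x500000: 0x2000})

def floor_offset(addr):
    """Map addr to its linear offset via the region at or below it,
    falling back to the lowest region (as navigate_to_addr does)."""
    try:
        region_addr = next(addr_to_offset.irange(maximum=addr, reverse=True))
    except StopIteration:
        region_addr = next(addr_to_offset.irange())
    return addr_to_offset[region_addr] + (addr - region_addr)

print(hex(floor_offset(0x401080)))  # 0x1080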
Example #20
class PageWidget(QWidget):
    move_drop_event = pyqtSignal(object, int, int)
    copy_drop_event = pyqtSignal(object, int, int)

    DRAG_MAGIC = 'LiSP_Drag&Drop'

    def __init__(self, rows, columns, *args):
        super().__init__(*args)
        self.setAcceptDrops(True)

        self.__rows = rows
        self.__columns = columns
        self.__widgets = SortedDict()

        self.setLayout(QGridLayout())
        self.layout().setContentsMargins(4, 4, 4, 4)
        self.init_layout()

    def init_layout(self):
        for row in range(0, self.__rows):
            self.layout().setRowStretch(row, 1)
            # item = QSpacerItem(0, 0, QSizePolicy.Minimum, QSizePolicy.Expanding)
            # self.layout().addItem(item, row, 0)

        for column in range(0, self.__columns):
            self.layout().setColumnStretch(column, 1)
            # item = QSpacerItem(0, 0, QSizePolicy.Expanding, QSizePolicy.Minimum)
            # self.layout().addItem(item, 0, column)

    def add_widget(self, widget, row, column):
        self._check_index(row, column)
        if (row, column) not in self.__widgets:
            widget.setSizePolicy(QSizePolicy.Ignored, QSizePolicy.Ignored)
            self.__widgets[(row, column)] = widget
            self.layout().addWidget(widget, row, column)
            widget.show()
        else:
            raise IndexError('cell {} already used'.format((row, column)))

    def take_widget(self, row, column):
        self._check_index(row, column)
        if (row, column) in self.__widgets:
            widget = self.__widgets.pop((row, column))
            widget.hide()
            self.layout().removeWidget(widget)
            return widget
        else:
            raise IndexError('cell {} is empty'.format((row, column)))

    def move_widget(self, o_row, o_column, n_row, n_column):
        widget = self.take_widget(o_row, o_column)
        self.add_widget(widget, n_row, n_column)

    def widget(self, row, column):
        self._check_index(row, column)
        return self.__widgets.get((row, column))

    def index(self, widget):
        for index, f_widget in self.__widgets.items():
            if widget is f_widget:
                return index

        return -1, -1

    def widgets(self):
        return iter(self.__widgets.values())

    def reset(self):
        self.__widgets.clear()

    def dragEnterEvent(self, event):
        if event.mimeData().hasText():
            if event.mimeData().text() == PageWidget.DRAG_MAGIC:
                event.accept()
            else:
                event.ignore()
        else:
            event.ignore()

    def dragLeaveEvent(self, event):
        event.ignore()

    def dropEvent(self, event):
        row, column = self._event_index(event)
        if self.layout().itemAtPosition(row, column) is None:
            if qApp.keyboardModifiers() == Qt.ControlModifier:
                event.setDropAction(Qt.MoveAction)
                event.accept()
                self.move_drop_event.emit(event.source(), row, column)
                return
            elif qApp.keyboardModifiers() == Qt.ShiftModifier:
                event.setDropAction(Qt.CopyAction)
                self.copy_drop_event.emit(event.source(), row, column)
                event.accept()
                return

        event.ignore()

    def dragMoveEvent(self, event):
        row, column = self._event_index(event)
        if self.layout().itemAtPosition(row, column) is None:
            event.accept()
        else:
            event.ignore()

    def _check_index(self, row, column):
        if not isinstance(row, int):
            raise TypeError('rows index must be integers, not {}'.format(
                row.__class__.__name__))
        if not isinstance(column, int):
            raise TypeError('columns index must be integers, not {}'.format(
                column.__class__.__name__))

        if not 0 <= row < self.__rows or not 0 <= column < self.__columns:
            raise IndexError('index out of bound {}'.format((row, column)))

    def _event_index(self, event):
        # Margins and spacings are equals
        space = self.layout().horizontalSpacing()
        margin = self.layout().contentsMargins().right()

        r_size = (self.height() + margin * 2) // self.__rows + space
        c_size = (self.width() + margin * 2) // self.__columns + space

        row = math.ceil(event.pos().y() / r_size) - 1
        column = math.ceil(event.pos().x() / c_size) - 1

        return row, column
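PageWidget keys its SortedDict by (row, column) tuples, so iterating widgets() yields cells in row-major order. A Qt-free sketch of that bookkeeping:

from sortedcontainers import SortedDict

cells = SortedDict()
cells[(1, 2)] = 'late'
cells[(0, 1)] = 'early'
cells[(0, 0)] = 'first'

print(list(cells))        # [(0, 0), (0, 1), (1, 2)] -- row-major order
print(cells.peekitem(0))  # ((0, 0), 'first'), the top-left occupied cell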
Example #21
class MemoryIndex(object):
    ''' An ad-hoc in-memory index for when an on-file index file does not
    exist. '''
    def __init__(self, tbl, colind):
        self.__idx = SortedDict(Key)
        init_idx_with(self, tbl, colind)

    def add(self, rowid, key):
        if key not in self.__idx:
            val = set()
            self.__idx[key] = val
        else:
            val = self.__idx[key]
        val.add(rowid)

    def clear(self):
        self.__idx.clear()

    def search(self, key, inequality):
        if inequality == '!=':
            for k in self.__idx:
                if compare(k, key, '!='):
                    yield from self.__idx[k]
        elif inequality in ('=', '=='):
            if key in self.__idx:
                yield from self.__idx[key]
        else:
            opers = {
                '<=': {
                    'minimum': None,
                    'maximum': key
                },
                '<': {
                    'minimum': None,
                    'maximum': key,
                    'inclusive': (False, False)
                },
                '>=': {
                    'minimum': key,
                    'maximum': None
                },
                '>': {
                    'minimum': key,
                    'maximum': None,
                    'inclusive': (False, False)
                }
            }
            for k in self.__idx.irange(**opers[inequality]):
                yield from self.__idx[k]

    def modify(self, old_rowid, new_rowid, key):
        if key not in self.__idx:
            return False
        s = self.__idx[key]
        if old_rowid not in s:
            return False
        s.remove(old_rowid)
        s.add(new_rowid)
        return True

    def delete(self, rowid, key):
        if key not in self.__idx:
            return False
        s = self.__idx[key]
        if rowid not in s:
            return False
        s.remove(rowid)
        return True
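The range queries in search() map directly onto SortedDict.irange. A standalone sketch of the '>=' and '>' cases, with made-up rows:

from sortedcontainers import SortedDict

idx = SortedDict()  # key -> set of rowids
for rowid, key in [(1, 10), (2, 20), (3, 20), (4, 30)]:
    idx.setdefault(key, set()).add(rowid)

# equivalent of search(20, '>='):
print([r for k in idx.irange(minimum=20) for r in sorted(idx[k])])  # [2, 3, 4]

# '>' excludes the boundary key via inclusive=(False, False):
print([r for k in idx.irange(minimum=20, inclusive=(False, False))
       for r in sorted(idx[k])])  # [4]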
Example #22
class CacheStore(object):
    class CacheItem(object):
        def __init__(self):
            self.valid = Event()
            self.data = None

    def __init__(self, key=None):
        self.lock = RLock()
        self.store = SortedDict(key)

    def __getitem__(self, item):
        return self.get(item)

    def put(self, key, data):
        with self.lock:
            item = self.store[key] if key in self.store else self.CacheItem()
            item.data = data
            item.valid.set()

            if key not in self.store:
                self.store[key] = item
                return True

            return False

    def update(self, **kwargs):
        with self.lock:
            items = {}
            created = []
            updated = []
            for k, v in kwargs.items():
                items[k] = self.CacheItem()
                items[k].data = v
                items[k].valid.set()
                if k in self.store:
                    updated.append(k)
                else:
                    created.append(k)

            self.store.update(**items)
            return created, updated

    def update_one(self, key, **kwargs):
        with self.lock:
            item = self.get(key)
            if not item:
                return False

            for k, v in kwargs.items():
                setattr(item, k, v)

            self.put(key, item)
            return True

    def update_many(self, key, predicate, **kwargs):
        with self.lock:
            updated = []
            for k, v in self.itervalid():
                if predicate(v):
                    if self.update_one(k, **kwargs):
                        updated.append(k)

            return updated

    def get(self, key, default=None, timeout=None):
        item = self.store.get(key)
        if item:
            item.valid.wait(timeout)
            return item.data

        return default

    def remove(self, key):
        with self.lock:
            if key in self.store:
                del self.store[key]
                return True

            return False

    def remove_many(self, keys):
        with self.lock:
            removed = []
            for key in keys:
                if key in self.store:
                    del self.store[key]
                    removed.append(key)

            return removed

    def clear(self):
        with self.lock:
            items = list(self.store.keys())
            self.store.clear()
            return items

    def exists(self, key):
        return key in self.store

    def rename(self, oldkey, newkey):
        with self.lock:
            obj = self.get(oldkey)
            obj['id'] = newkey
            self.put(newkey, obj)
            self.remove(oldkey)

    def is_valid(self, key):
        item = self.store.get(key)
        if item:
            return item.valid.is_set()

        return False

    def invalidate(self, key):
        with self.lock:
            item = self.store.get(key)
            if item:
                item.valid.clear()

    def itervalid(self):
        for key, value in list(self.store.items()):
            if value.valid.is_set():
                yield (key, value.data)

    def validvalues(self):
        for value in list(self.store.values()):
            if value.valid.is_set():
                yield value.data

    def remove_predicate(self, predicate):
        result = []
        for k, v in self.itervalid():
            if predicate(v):
                self.remove(k)
                result.append(k)

        return result

    def query(self, *filter, **params):
        return query(list(self.validvalues()), *filter, **params)
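A minimal usage sketch for CacheStore. get() blocks on the entry's Event until put() marks it valid, so passing a timeout is advisable; after invalidate(), a get() with a timeout still returns the (now stale) data once the wait expires:

store = CacheStore()
store.put('users.1', {'id': 'users.1', 'name': 'alice'})
print(store.get('users.1')['name'])        # alice
print(store.is_valid('users.1'))           # True

store.invalidate('users.1')
print(store.is_valid('users.1'))           # False
print(store.get('users.1', timeout=0.1))   # stale data, after the wait
print(store.remove('users.1'))             # True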
Example #23
class QLinearDisassembly(QDisassemblyBaseControl, QAbstractScrollArea):
    OBJECT_PADDING = 0

    def __init__(self, workspace, disasm_view, parent=None):
        QDisassemblyBaseControl.__init__(self, workspace, disasm_view, QAbstractScrollArea)
        QAbstractScrollArea.__init__(self, parent=parent)

        self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
        self.horizontalScrollBar().setSingleStep(Conf.disasm_font_width)
        self.verticalScrollBar().setSingleStep(16)

        # self.setTransformationAnchor(QGraphicsView.NoAnchor)
        # self.setResizeAnchor(QGraphicsView.NoAnchor)
        # self.setAlignment(Qt.AlignLeft)

        self._viewer = None  # type: QLinearDisassemblyView

        self._line_height = Conf.disasm_font_height

        self._offset_to_region = SortedDict()
        self._addr_to_region_offset = SortedDict()
        # Offset (in bytes) into the entire blanket view
        self._offset = 0
        # The maximum offset (in bytes) of the blanket view
        self._max_offset = None
        # The first line that is rendered of the first object in self.objects. Start from 0.
        self._start_line_in_object = 0

        self._disasms = { }
        self.objects = [ ]

        self.verticalScrollBar().actionTriggered.connect(self._on_vertical_scroll_bar_triggered)

        self._init_widgets()

    def reload(self):
        self.initialize()

    #
    # Properties
    #

    @property
    def offset(self):
        return self._offset

    @offset.setter
    def offset(self, v):
        self._offset = v

    @property
    def max_offset(self):
        if self._max_offset is None:
            self._max_offset = self._calculate_max_offset()
        return self._max_offset

    @property
    def cfg(self):
        return self.workspace.instance.cfg

    @property
    def cfb(self):
        return self.workspace.instance.cfb

    @property
    def scene(self):
        return self._viewer._scene

    #
    # Events
    #

    def resizeEvent(self, event):
        old_height = event.oldSize().height()
        new_height = event.size().height()
        self._viewer._scene.setSceneRect(QRectF(0, 0, event.size().width(), new_height))

        if new_height > old_height:
            # we probably need more objects generated
            curr_offset = self._offset
            self._offset = None  # force a re-generation of objects
            self.prepare_objects(curr_offset, start_line=self._start_line_in_object)
            self.redraw()

        super().resizeEvent(event)

    def wheelEvent(self, event):
        """
        :param QWheelEvent event:
        :return:
        """
        delta = event.delta()
        if delta < 0:
            # scroll down by some lines
            lines = min(int(-delta // self._line_height), 3)
            self.prepare_objects(self.offset, start_line=self._start_line_in_object + lines)
        elif delta > 0:
            # Scroll up by some lines
            lines = min(int(delta // self._line_height), 3)
            self.prepare_objects(self.offset, start_line=self._start_line_in_object - lines)

        self.verticalScrollBar().setValue(self.offset * self._line_height)
        event.accept()
        self.viewport().update()

    def _on_vertical_scroll_bar_triggered(self, action):

        if action == QAbstractSlider.SliderSingleStepAdd:
            # scroll down by one line
            self.prepare_objects(self.offset, start_line=self._start_line_in_object + 1)
            self.viewport().update()
        elif action == QAbstractSlider.SliderSingleStepSub:
            # Scroll up by one line
            self.prepare_objects(self.offset, start_line=self._start_line_in_object - 1)
            self.viewport().update()
        elif action == QAbstractSlider.SliderPageStepAdd:
            # Scroll down by one page
            lines_per_page = int(self.height() // self._line_height)
            self.prepare_objects(self.offset,
                                 start_line=self._start_line_in_object + lines_per_page)
            self.viewport().update()
        elif action == QAbstractSlider.SliderPageStepSub:
            # Scroll up by one page
            lines_per_page = int(self.height() // self._line_height)
            self.prepare_objects(self.offset,
                                 start_line=self._start_line_in_object - lines_per_page)
            self.viewport().update()
        elif action == QAbstractSlider.SliderMove:
            # Setting a new offset
            new_offset = int(self.verticalScrollBar().value() // self._line_height)
            self.prepare_objects(new_offset)
            self.viewport().update()

    #
    # Public methods
    #

    def redraw(self):
        if self._viewer is not None:
            self._viewer.redraw()

    def refresh(self):
        self._update_size()
        self.redraw()

    def initialize(self):

        if self.cfb is None:
            return

        self._addr_to_region_offset.clear()
        self._offset_to_region.clear()
        self._disasms.clear()
        self._offset = None
        self._max_offset = None
        self._start_line_in_object = 0

        # enumerate memory regions
        byte_offset = 0
        for mr in self.cfb.regions:  # type: MemoryRegion
            if mr.type in {'tls', 'kernel'}:
                # Skip TLS objects and kernel objects
                continue
            self._addr_to_region_offset[mr.addr] = byte_offset
            self._offset_to_region[byte_offset] = mr
            byte_offset += mr.size

        self._update_size()

    def goto_function(self, func):
        if func.addr not in self._block_addr_map:
            _l.error('Unable to find entry block for function %s', func)
            return
        view_height = self.viewport().height()
        desired_center_y = self._block_addr_map[func.addr].pos().y()
        _l.debug('Going to function at 0x%x by scrolling to %s', func.addr, desired_center_y)
        self.verticalScrollBar().setValue(desired_center_y - (view_height / 3))

    def show_instruction(self, insn_addr, insn_pos=None, centering=False, use_block_pos=False):
        """

        :param insn_addr:
        :param insn_pos:
        :param centering:
        :param use_block_pos:
        :return:
        """

        if insn_pos is not None:
            # check if item is already visible in the viewport
            viewport = self._viewer.viewport()
            rect = self._viewer.mapToScene(QRect(0, 0, viewport.width(), viewport.height())).boundingRect()
            if rect.contains(insn_pos):
                return

        self.navigate_to_addr(insn_addr)

    def navigate_to_addr(self, addr):
        if not self._addr_to_region_offset:
            return
        try:
            floor_region_addr = next(self._addr_to_region_offset.irange(maximum=addr, reverse=True))
        except StopIteration:
            floor_region_addr = next(self._addr_to_region_offset.irange())
        floor_region_offset = self._addr_to_region_offset[floor_region_addr]

        offset_into_region = addr - floor_region_addr
        self.navigate_to(floor_region_offset + offset_into_region)

    def navigate_to(self, offset):
        self.verticalScrollBar().setValue(offset * self._line_height)
        self.prepare_objects(offset, start_line=0)

    #
    # Private methods
    #

    def _init_widgets(self):
        self._viewer = QLinearDisassemblyView(self)

        layout = QHBoxLayout()
        layout.addWidget(self._viewer)
        layout.setContentsMargins(0, 0, 0, 0)

        self.setLayout(layout)

    def _update_size(self):

        # ask all objects to update their sizes
        for obj in self.objects:
            obj.clear_cache()
            obj.refresh()

        # update vertical scrollbar
        self.verticalScrollBar().setRange(0, self.max_offset * self._line_height - self.height() // 2)
        offset = 0 if self.offset is None else self.offset
        self.verticalScrollBar().setValue(offset * self._line_height)

    def clear_objects(self):
        self.objects.clear()
        self._offset = None

    def prepare_objects(self, offset, start_line=0):
        """
        Prepare objects to print based on offset and start_line. Update self.objects, self._offset, and
        self._start_line_in_object.
        :param int offset:      Beginning offset (in bytes) to display in the linear viewer.
        :param int start_line:  The first line into the first object to display in the linear viewer.
        :return:                None
        """

        if offset is None:
            offset = 0

        if offset == self._offset and start_line == self._start_line_in_object:
            return

        # Convert the offset to memory region
        base_offset, mr = self._region_from_offset(offset)  # type: int, MemoryRegion
        if mr is None:
            return

        addr = self._addr_from_offset(mr, base_offset, offset)
        _l.debug("Address %#x, offset %d, start_line %d.", addr, offset, start_line)

        self._insaddr_to_block.clear()

        if start_line < 0:
            # Which object are we currently displaying at the top of the disassembly view?
            try:
                top_obj_addr = self.cfb.floor_addr(addr=addr)
            except KeyError:
                top_obj_addr = addr

            # Reverse-iterate until we have enough lines to compensate start_line
            for obj_addr, obj in self.cfb.ceiling_items(addr=top_obj_addr, reverse=True, include_first=False):
                qobject = self._obj_to_paintable(obj_addr, obj)
                if qobject is None:
                    continue
                object_lines = int(qobject.height // self._line_height)
                _l.debug("Compensating negative start_line: object %s, object_lines %d.", obj, object_lines)
                start_line += object_lines
                if start_line >= 0:
                    addr = obj_addr
                    # Update offset
                    new_region_addr = next(self._addr_to_region_offset.irange(maximum=addr, reverse=True))
                    new_region_offset = self._addr_to_region_offset[new_region_addr]
                    offset = (addr - new_region_addr) + new_region_offset
                    break
            else:
                # umm we don't have enough objects to compensate the negative start_line
                start_line = 0
                # update addr and offset to their minimal values
                addr = next(self._addr_to_region_offset.irange())
                offset = self._addr_to_region_offset[addr]

        _l.debug("After adjustment: Address %#x, offset %d, start_line %d.", addr, offset, start_line)

        scene = self.scene
        # remove existing objects
        for obj in self.objects:
            scene.removeItem(obj)
        self.objects = [ ]

        viewable_lines = int(self.height() // self._line_height)
        lines = 0
        start_line_in_object = 0

        # Load a page of objects
        x = 80
        y = -start_line * self._line_height

        for obj_addr, obj in self.cfb.floor_items(addr=addr):
            qobject = self._obj_to_paintable(obj_addr, obj)
            _l.debug("Converted %s to %s at %x.", obj, qobject, obj_addr)
            if qobject is None:
                # Conversion failed
                continue

            if isinstance(qobject, QLinearBlock):
                for insn_addr in qobject.addr_to_insns.keys():
                    self._insaddr_to_block[insn_addr] = qobject

            # qobject.setCacheMode(QGraphicsItem.DeviceCoordinateCache)

            object_lines = int(qobject.height // self._line_height)

            if start_line >= object_lines:
                # this object should be skipped. ignore it
                start_line -= object_lines
                # adjust the offset as well
                if obj_addr <= addr < obj_addr + obj.size:
                    offset += obj_addr + obj.size - addr
                else:
                    offset += obj.size
                _l.debug("Skipping object %s (size %d). New offset: %d.", obj, obj.size, offset)
                y = -start_line * self._line_height
            else:
                if start_line > 0:
                    _l.debug("First object to paint: %s (size %d). Current offset %d. Start printing from line %d. "
                             "Y pos %d.", obj, obj.size, offset, start_line, y)
                    # this is the first object to paint
                    start_line_in_object = start_line
                    start_line = 0
                    lines += object_lines - start_line_in_object
                else:
                    lines += object_lines
                self.objects.append(qobject)
                qobject.setPos(x, y)
                scene.addItem(qobject)
                y += qobject.height + self.OBJECT_PADDING

            if lines > viewable_lines:
                break

        _l.debug("Final offset %d, start_line_in_object %d.", offset, start_line_in_object)

        # Update properties
        self._offset = offset
        self._start_line_in_object = start_line_in_object

    def _obj_to_paintable(self, obj_addr, obj):
        if isinstance(obj, Block):
            cfg_node = self.cfg.get_any_node(obj_addr, force_fastpath=True)
            func_addr = cfg_node.function_address
            if self.workspace.instance.kb.functions.contains_addr(func_addr):
                func = self.workspace.instance.kb.functions[func_addr]
                disasm = self._get_disasm(func)
                qobject = QLinearBlock(self.workspace, func_addr, self.disasm_view, disasm,
                                       self.disasm_view.infodock, obj.addr, [obj], {}, None, container=self._viewer,
                                       )
            else:
                # TODO: Get disassembly even if the function does not exist
                _l.warning("Function %s does not exist, and we cannot get disassembly for block %s.",
                           func_addr, obj)
                qobject = None
        elif isinstance(obj, MemoryData):
            qobject = QMemoryDataBlock(self.workspace, self.disasm_view.infodock, obj_addr, obj, parent=None,
                                       container=self._viewer)
        elif isinstance(obj, Unknown):
            qobject = QUnknownBlock(self.workspace, obj_addr, obj.bytes, container=self._viewer)
        else:
            qobject = None
        return qobject

    def _calculate_max_offset(self):
        try:
            max_off = next(self._offset_to_region.irange(reverse=True))
            mr = self._offset_to_region[max_off]  # type: MemoryRegion
            return max_off + mr.size
        except StopIteration:
            return 0

    def _region_from_offset(self, offset):
        try:
            off = next(self._offset_to_region.irange(maximum=offset, reverse=True))
            return off, self._offset_to_region[off]
        except StopIteration:
            return None, None

    def _addr_from_offset(self, mr, base_offset, offset):
        return mr.addr + (offset - base_offset)

    def _get_disasm(self, func):
        """

        :param func:
        :return:
        """

        if func.addr not in self._disasms:
            self._disasms[func.addr] = self.workspace.instance.project.analyses.Disassembly(function=func)
        return self._disasms[func.addr]
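The region bookkeeping in this example leans on a floor-lookup idiom: `next(sorted_dict.irange(maximum=x, reverse=True))` returns the greatest key <= x. A minimal, self-contained sketch of that idiom, with hypothetical region addresses and offsets (not taken from the code above):

from sortedcontainers import SortedDict

# Hypothetical mapping: region base address -> flat offset in the linear view.
addr_to_offset = SortedDict({0x400000: 0x0, 0x600000: 0x1000, 0x7f0000: 0x5000})

def linear_offset(addr):
    # Floor lookup: greatest region base <= addr, plus the delta within the region.
    base = next(addr_to_offset.irange(maximum=addr, reverse=True))
    return addr_to_offset[base] + (addr - base)

print(hex(linear_offset(0x600010)))  # 0x1010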
Exemplo n.º 24
0
class ProductOrderBook:
    def __init__(self, product_id):
        self.product_id = product_id
        self._asks = SortedDict(lambda key: float(key))
        self._asks_lock = Lock()
        self._bids = SortedDict(lambda key: neg(float(key)))
        self._bids_lock = Lock()
        self._first_bids_lock = Lock()
        self._first_bids_lock.acquire()
        self._first_asks_lock = Lock()
        self._first_asks_lock.acquire()
        self._update_callbacks = {}

    def add_update_callback(self, callback: Callable):
        """Add a callback to be called on every update. The callback will be called with 'self' as a parameter.

    Returns:
      A unique identifier (str) that can be used to remove the callback in the future.
    """
        identifier = str(uuid.uuid4())
        self._update_callbacks[identifier] = callback
        return identifier

    def remove_update_callback(self, identifier: Text):
        """Removes the callback by it's identifier."""
        del self._update_callbacks[identifier]

    def top_n_string(self, n=None):
        """Returns the "Top-N" asks/bids in the order-book in string form.

    Params:
      n: How many of the top
    """
        with self._bids_lock, self._asks_lock:  # acquire both locks ('and' would acquire only the second)
            return ProductOrderBook._make_formatted_string(
                bids=ProductOrderBook._make_sorted_dict_slice(self._bids,
                                                              stop=n),
                asks=ProductOrderBook._make_sorted_dict_slice(self._asks,
                                                              stop=n))

    def get_book(self, top_n=None):
        """Returns the order book as a dict with keys 'asks' and 'bids' and tuples of [price, size].

    Params:
      top_n: The depth of the order book to return.
    """
        return {
            'asks': self.get_asks(top_n=top_n),
            'bids': self.get_bids(top_n=top_n)
        }

    def get_asks(self, top_n=None):
        """Get the 'asks' part of the order book.

    Params:
      top_n: The depth of the order book to return.
    """
        with self._asks_lock:
            return ProductOrderBook._make_slice(self._asks, stop=top_n)

    def get_bids(self, top_n=None):
        """Get the 'bids' part of the order book.

        Params:
          top_n: The depth of the order book to return.
        """
        with self._bids_lock:
            return ProductOrderBook._make_slice(self._bids, stop=top_n)

    # Private API Below this Line.
    def _call_callbacks(self):
        for callback in self._update_callbacks.values():
            callback(self)

    def _init_bids(self, bids):
        with self._bids_lock:
            self._bids.clear()  # init should clear all current bids.
            for price, size in bids:
                self._bids[price] = float(size)
            self._first_bids_lock.release()
        self._call_callbacks()

    def _init_asks(self, asks):
        with self._asks_lock:
            self._asks.clear()  # init should clear all current asks.
            for price, size in asks:
                self._asks[price] = float(size)
            self._first_asks_lock.release()
        self._call_callbacks()

    def _consume_changes(self, changes):
        for side, price, size in changes:
            if side == 'buy':
                self._consume_buy(price, size)
            elif side == 'sell':
                self._consume_sell(price, size)
        self._call_callbacks()

    def _consume_buy(self, price, size):
        fsize = float(size)
        # Wait for _init_bids to run.
        if self._first_bids_lock.locked():
            self._first_bids_lock.acquire()
            self._first_bids_lock.release()
        with self._bids_lock:
            if fsize == 0.0:
                self._bids.pop(price, None)  # a zero size removes the price level
            else:
                self._bids[price] = fsize

    def _consume_sell(self, price, size):
        fsize = float(size)
        # Wait for _init_asks to run.
        if self._first_asks_lock.locked():
            self._first_asks_lock.acquire()
            self._first_asks_lock.release()
        with self._asks_lock:
            if fsize == 0.0:
                self._asks.pop(price, None)  # a zero size removes the price level
            else:
                self._asks[price] = fsize

    @staticmethod
    def _make_formatted_string(bids, asks):
        overall_format = "BIDS:\n{}\n\nASKS:\n{}\n\n"
        format_str = 'PRICE: {}, SIZE: {}'
        return overall_format.format(
            '\n'.join(
                format_str.format(str(price), str(bids[price]))
                for price in bids.keys()), '\n'.join(
                    format_str.format(str(price), str(asks[price]))
                    for price in asks.keys()))

    def __repr__(self):
        """Print the entire order book."""
        with self._asks_lock, self._bids_lock:
            return ProductOrderBook._make_formatted_string(
                self._bids, self._asks)

    @staticmethod
    def _make_sorted_dict_slice(orders: SortedDict, start=None, stop=None):
        return SortedDict(orders.key,
                          [(key, orders[key])
                           for key in orders.islice(start=start, stop=stop)])

    @staticmethod
    def _make_slice(orders: SortedDict, start=None, stop=None):
        return [(key, orders[key])
                for key in orders.islice(start=start, stop=stop)]
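One detail worth calling out in the class above: acquiring two locks requires `with lock_a, lock_b:`. Writing `with lock_a and lock_b:` evaluates the boolean expression first, which yields a single object (the second lock, since an unlocked Lock is truthy), so only that lock is acquired. A small standalone demonstration:

from threading import Lock

a, b = Lock(), Lock()

with a and b:  # `a and b` evaluates to b, so only b is acquired
    print(a.locked(), b.locked())  # False True

with a, b:  # acquires a then b, releases b then a
    print(a.locked(), b.locked())  # True True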
Exemplo n.º 25
0
def test6():
    """
    有序的map: SortedDict
    网址: http://www.grantjenks.com/docs/sortedcontainers/sorteddict.html
    """
    from sortedcontainers import SortedDict
    sd = SortedDict()
    # Insert and delete elements
    sd["wxx"] = 21
    sd["hh"] = 18
    sd["other"] = 20
    print(sd)  # SortedDict({'hh': 18, 'other': 20, 'wxx': 21})
    print(sd["wxx"])  # 访问不存在的键会报错, KeyError
    print(sd.get("c"))  # 访问不存在的键会返回None     None
    # SortedDict转dict
    print(dict(sd))  # {'hh': 18, 'other': 20, 'wxx': 21}
    # Return the first and the last element
    print(sd.peekitem(0))  # tuple; returns the first item    ('hh', 18)
    print(sd.peekitem())  # tuple; returns the last item    ('wxx', 21)
    # Iteration
    for k, v in sd.items():
        print(k, ':', v, sep="", end=", ")  # sep="" removes the space between the printed fields
    print()
    for k in sd:  # iterate over keys; equivalent to: for k in sd.keys():
        print(str(k) + ":" + str(sd[k]), end=", ")
    print()
    for v in sd.values():  # iterate over values
        print(v, end=", ")
    print()
    # Return a key from the map
    print(sd.peekitem()[0])
    # Return a value from the map
    print(sd.peekitem()[1])
    # Membership test
    print("wxx" in sd)  # True
    # bisect_left() / bisect_right()
    sd["a"] = 1
    sd["c1"] = 2
    sd["c2"] = 4
    print(
        sd
    )  # SortedDict({'a': 1, 'c1': 2, 'c2': 4, 'hh': 18, 'other': 20, 'wxx': 21})
    print(sd.bisect_left("c1"))  # index of the first key >= "c1"    1
    print(sd.bisect_right("c1"))  # index of the first key > "c1"    2
    # Clear
    sd.clear()
    print(len(sd))  # 0
    print(len(sd) == 0)  # True
    """
    无序的map: dict
    """
    print("---------------------------------------")
    d = {"c1": 2, "c2": 4, "hh": 18, "wxx": 21, 13: 14, 1: 0}
    print(d["wxx"])  # 21
    print(d[13])  # 14
    d[13] += 1
    print(d[13])  # 15
    d["future"] = "wonderful"  # 字典中添加键值对
    del d[1]  # 删除字典d中键1对应的数据值
    print("wxx" in d)  # 判断键"wxx"是否在字典d中,如果在返回True,否则False
    print(d.keys())  # 返回字典d中所有的键信息  dict_keys(['c1', 'c2', 'hh', 'wxx', 13])
    print(d.values())  # 返回字典d中所有的值信息  dict_values([2, 4, 18, 21, 14])
    print(d.items(
    ))  # dict_items([('c1', 2), ('c2', 4), ('hh', 18), ('wxx', 21), (13, 14)])
    for k, v in d.items():  # iterate over (k, v) pairs
        print(k, ':', v)
    for k in d:  # iterate over keys; equivalent to: for k in d.keys():
        print(str(k) + ":" + str(d[k]), end=", ")
    print()
    for v in d.values():  # iterate over values
        print(v, end=", ")
    print()
    # dict functions and methods
    print("---------------------------------------")
    d = {"中国": "北京", "美国": "华盛顿", "法国": "巴黎"}
    print(len(d))  # number of entries in d  3
    print(d.get("中国", "不存在"))  # value if the key exists, else the <default>  北京
    print(d.get("中", "不存在"))  # 不存在
    print(d.get("中"))  # None
    d["美国"] = "Washington"  # 修改键对应的值
    print(d.pop("美国"))  # 键k存在,则返回相应值,并将其从dict中删除
    print(d.popitem())  # 随机从字典d中取出一个键值对,以元组形式返回,并将其从dict中删除
    d.clear()  # 删除所有的键值对
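The demo above does not cover range queries; as a supplement, a short sketch of `irange` (key-range iteration) and `islice` (index-range iteration), which several of the other examples rely on:

from sortedcontainers import SortedDict

sd = SortedDict({"a": 1, "c1": 2, "c2": 4, "hh": 18})
print(list(sd.irange("c1", "hh")))  # keys in ["c1", "hh"], inclusive: ['c1', 'c2', 'hh']
print(list(sd.irange(maximum="c2", reverse=True)))  # keys <= "c2", descending: ['c2', 'c1', 'a']
print(list(sd.islice(1, 3)))  # keys at index positions 1..2: ['c1', 'c2']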
Exemplo n.º 26
0
class CacheStore(object):
    class CacheItem(object):
        __slots__ = ('valid', 'data')

        def __init__(self):
            self.valid = Event()
            self.data = None

    def __init__(self, key=None):
        self.lock = RLock()
        self.store = SortedDict(key)

    def __getitem__(self, item):
        return self.get(item)

    def put(self, key, data):
        with self.lock:
            try:
                item = self.store[key]
                item.data = data
                item.valid.set()
                return False
            except KeyError:
                item = self.CacheItem()
                item.data = data
                item.valid.set()
                self.store[key] = item
                return True

    def update(self, **kwargs):
        with self.lock:
            items = {}
            created = []
            updated = []
            for k, v in kwargs.items():
                items[k] = self.CacheItem()
                items[k].data = v
                items[k].valid.set()
                if k in self.store:
                    updated.append(k)
                else:
                    created.append(k)

            self.store.update(**items)
            return created, updated

    def update_one(self, key, **kwargs):
        with self.lock:
            item = self.get(key)
            if not item:
                return False

            for k, v in kwargs.items():
                setattr(item, k, v)

            self.put(key, item)
            return True

    def update_many(self, key, predicate, **kwargs):
        with self.lock:
            updated = []
            for k, v in self.itervalid():
                if predicate(v):
                    if self.update_one(k, **kwargs):
                        updated.append(k)

            return updated

    def get(self, key, default=None, timeout=None):
        item = self.store.get(key)
        if item:
            item.valid.wait(timeout)
            return item.data

        return default

    def remove(self, key):
        with self.lock:
            try:
                del self.store[key]
                return True
            except KeyError:
                return False

    def remove_many(self, keys):
        with self.lock:
            removed = []
            for key in keys:
                try:
                    del self.store[key]
                    removed.append(key)
                except KeyError:
                    pass

            return removed

    def clear(self):
        with self.lock:
            items = list(self.store.keys())
            self.store.clear()
            return items

    def exists(self, key):
        return key in self.store

    def rename(self, oldkey, newkey):
        with self.lock:
            obj = self.get(oldkey)
            obj['id'] = newkey
            self.put(newkey, obj)
            self.remove(oldkey)

    def is_valid(self, key):
        item = self.store.get(key)
        if item:
            return item.valid.is_set()

        return False

    def invalidate(self, key):
        with self.lock:
            item = self.store.get(key)
            if item:
                item.valid.clear()

    def itervalid(self):
        for key, value in list(self.store.items()):
            if value.valid.is_set():
                yield (key, value.data)

    def validvalues(self):
        for value in list(self.store.values()):
            if value.valid.is_set():
                yield value.data

    def remove_predicate(self, predicate):
        result = []
        for k, v in self.itervalid():
            if predicate(v):
                self.remove(k)
                result.append(k)

        return result

    def query(self, *filter, **params):
        return query(list(self.validvalues()), *filter, **params)
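A minimal usage sketch for the class above, assuming CacheStore and its imports (threading.Event, threading.RLock, sortedcontainers.SortedDict) are in scope; it shows get() blocking on the validity Event until another thread publishes a value:

import threading
import time

cache = CacheStore()
cache.put('answer', None)   # entry exists and is marked valid
cache.invalidate('answer')  # clear the Event: readers will now block

def writer():
    time.sleep(0.1)
    cache.put('answer', 42)  # sets the data and re-sets the Event

threading.Thread(target=writer).start()
print(cache.get('answer', timeout=1.0))  # blocks ~0.1 s, then prints 42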
Exemplo n.º 27
0
class FreshPondSim:
    def __init__(self,
                 distance,
                 start_time,
                 end_time,
                 entrances,
                 entrance_weights,
                 rand_velocities_and_distances_func,
                 entrance_rate,
                 entrance_rate_integral=None,
                 entrance_rate_integral_inverse=None,
                 interpolate_rate=True,
                 interpolate_rate_integral=True,
                 interpolate_res=None,
                 snap_exit=True):
        assert_positive_real(distance, 'distance')
        assert_real(start_time, 'start_time')
        assert_real(end_time, 'end_time')
        if not (start_time < end_time):
            raise ValueError(f"start_time should be less than end_time")
        assert len(entrances) == len(entrance_weights)
        self.start_time = start_time
        self.end_time = end_time
        self.dist_around = distance
        self.entrances = entrances
        self.entrance_weights = entrance_weights
        self.rand_velocities_and_distances = rand_velocities_and_distances_func
        self._snap_exit = snap_exit

        if interpolate_rate or interpolate_rate_integral:
            if interpolate_res is None:
                raise ValueError("Specify interpolate_res for interpolation")

        if interpolate_rate:
            self.entrance_rate = DynamicBoundedInterpolator(
                entrance_rate, start_time, end_time, interpolate_res)
        else:
            self.entrance_rate = entrance_rate

        if interpolate_rate_integral:  # we want to interpolate the integral function
            if entrance_rate_integral is None: # No integral function given
                # Do numerical integration and interpolate to speed it up
                def integral_func(t):
                    y, abserr = integrate.quad(entrance_rate, start_time, t)
                    return y

                self.entrance_rate_integral = DynamicBoundedInterpolator(
                    integral_func, start_time, end_time, interpolate_res)
            else: # Integral function was provided
                # Use the provided rate integral function but interpolate it
                self.entrance_rate_integral = DynamicBoundedInterpolator(
                    entrance_rate_integral, start_time, end_time, interpolate_res)
        else: # Don't want to interpolate the integral function
            # If entrance_rate_integral is not None (i.e. is provided) then
            # that function will be used as the rate integral.
            # If entrance_rate_integral is None, numerical integration will
            # be used.
            self.entrance_rate_integral = entrance_rate_integral

        self.entrance_rate_integral_inverse = entrance_rate_integral_inverse

        self.pedestrians = SortedKeyList(key=attrgetter('start_time'))
        
        self._counts = SortedDict()
        self._counts[self.start_time] = 0

        self._counts_are_correct = True

        self.refresh_pedestrians()

    def _distance(self, a, b):
        """signed distance of a relative to b"""
        return circular_diff(a % self.dist_around, b % self.dist_around,
                             self.dist_around)

    def _distance_from(self, b):
        """returns a function that returns the signed sitance from b"""
        return lambda a: self._distance(a, b)

    def _abs_distance_from(self, b):
        """returns a function that returns the distance from b"""
        return lambda a: abs(self._distance(a, b))

    def _closest_exit(self, dist):
        """Returns the closest number to dist that is equivalent mod dist_around
        to an element of entrances"""
        closest_exit = min(self.entrances, key=self._abs_distance_from(dist))
        diff = self._distance(closest_exit, dist)
        corrected_dist = dist + diff
        return corrected_dist

    def refresh_pedestrians(self):
        """Refreshes the pedestrians in the simulation to random ones"""
        self.clear_pedestrians()

        start_times = list(
            random_times(self.start_time, self.end_time,
                         self.entrance_rate,
                         self.entrance_rate_integral,
                         self.entrance_rate_integral_inverse))
        n_pedestrians = len(start_times)
        entrances = random.choices(population=self.entrances,
                                   weights=self.entrance_weights,
                                   k=n_pedestrians)
        velocities, distances = self.rand_velocities_and_distances(
            n_pedestrians).T

        def pedestrians_generator():
            for start_time, entrance, velocity, dist in zip(
                    start_times, entrances, velocities, distances):
                assert dist > 0
                if self._snap_exit:
                    original_exit = entrance + dist * sign(velocity)
                    corrected_exit = self._closest_exit(original_exit)
                    corrected_dist = abs(corrected_exit - entrance)
                    if math.isclose(corrected_dist, 0, abs_tol=1e-10):
                        corrected_dist = self.dist_around
                else:
                    corrected_dist = dist
                yield FreshPondPedestrian(self.dist_around, entrance,
                                          corrected_dist, start_time, velocity)

        self.add_pedestrians(pedestrians_generator())

    def clear_pedestrians(self):
        """Removes all pedestrains in the simulation"""
        self.pedestrians.clear()
        self._reset_counts()
        self._counts_are_correct = True

    def add_pedestrians(self, pedestrians):
        """Adds all the given pedestrians to the simulation"""
        def checked_pedestrians():
            for p in pedestrians:
                self._assert_pedestrian_in_range(p)
                yield p

        initial_num_pedestrians = self.num_pedestrians()
        self.pedestrians.update(checked_pedestrians())
        final_num_pedestrians = self.num_pedestrians()

        if final_num_pedestrians > initial_num_pedestrians:
            self._counts_are_correct = False
        else:
            assert final_num_pedestrians == initial_num_pedestrians

    def _assert_pedestrian_in_range(self, p):
        """Makes sure the pedestrian's start time is in the simulation's
        time interval"""
        if not (self.start_time <= p.start_time < self.end_time):
            raise ValueError(
                "Pedestrian start time is not in range [start_time, end_time)")

    def add_pedestrian(self, p):
        """Adds a new pedestrian to the simulation"""
        self._assert_pedestrian_in_range(p)
        self.pedestrians.add(p)

        # Update counts only when counts are correct
        if self._counts_are_correct:
            # add a new breakpoint at the pedestrian's start time if it is not already there
            self._counts[p.start_time] = self.n_people(p.start_time)

            # add a new breakpoint at the pedestrian's end time if it is not already there
            self._counts[p.end_time] = self.n_people(p.end_time)

            # increment all the counts in the pedestrian's interval of time
            # inclusive on the left, exclusive on the right
            # If it were inclusive on the right, then the count would be one more
            # than it should be in the period after end_time and before the next
            # breakpoint after end_time
            for t in self._counts.irange(p.start_time,
                                         p.end_time,
                                         inclusive=(True, False)):
                self._counts[t] += 1

    def _reset_counts(self):
        """Clears _counts and sets count at start_time to 0"""
        self._counts.clear()
        self._counts[self.start_time] = 0

    def _recompute_counts(self):
        """Store how many people there are whenever someone enters or exits so
        the number of people at a given time can be found quickly later"""
        # print("Recomputing counts")
        self._reset_counts()

        if self.num_pedestrians() == 0:
            return

        # pedestrians are already sorted by start time
        start_times = [p.start_time for p in self.pedestrians]
        end_times = sorted([p.end_time for p in self.pedestrians])

        n = len(start_times)
        curr_count = 0  # current number of people
        start_times_index = 0
        end_times_index = 0
        starts_done = False  # whether all the start times have been added
        ends_done = False  # whether all the end times have been added
        while not (starts_done and ends_done):
            # determine whether a start time or an end time should be added next
            # store this in the variable take_start which is true if a start
            # time should be added next
            if starts_done:
                # already added all the start times; add an end time
                take_start = False
            elif ends_done:
                # already added all the end times; add a start time
                take_start = True
            else:
                # didn't add all the end times nor all the start times
                # add the time that is earliest
                next_start_time = start_times[start_times_index]
                next_end_time = end_times[end_times_index]
                take_start = next_start_time < next_end_time

            if take_start:
                # add next start
                curr_count += 1
                start_time = start_times[start_times_index]
                self._counts[start_time] = curr_count
                start_times_index += 1
                if start_times_index == n:
                    starts_done = True
            else:
                # add next end
                curr_count -= 1
                end_time = end_times[end_times_index]
                self._counts[end_time] = curr_count
                end_times_index += 1
                if end_times_index == n:
                    ends_done = True

    def n_unique_people_saw(self, p):
        """Returns the number of unique people that a pedestrian sees"""
        n = 0
        for q in self.pedestrians:
            if p.intersects(q):
                n += 1
        return n

    def n_people_saw(self, p):
        """Returns the number of times a pedestrian sees someone"""
        n = 0
        for q in self.pedestrians:
            if p.end_time > q.start_time and p.start_time < q.end_time:
                n += p.n_intersections(q)
        return n

    def intersection_directions(self, p):
        """Returns the number of people seen going in the same direction and the
        number of people seen going in the opposite direction by p as a tuple"""
        n_same, n_diff = 0, 0
        for q in self.pedestrians:
            if p.end_time > q.start_time and p.start_time < q.end_time:
                d = q.intersection_direction(p)
                if d == 1:
                    n_same += 1
                elif d == -1:
                    n_diff += 1
        return n_same, n_diff

    def intersection_directions_total(self, p):
        n_same, n_diff = 0, 0
        for q in self.pedestrians:
            if p.end_time > q.start_time and p.start_time < q.end_time:
                i = p.total_intersection_direction(q)
                if i < 0:
                    n_diff += -i
                elif i > 0:
                    n_same += i
        return n_same, n_diff

    def n_people(self, t):
        """Returns the number of people at a given time"""

        if not self._counts_are_correct:
            self._recompute_counts()
            self._counts_are_correct = True

        if t in self._counts:
            return self._counts[t]
        elif t < self.start_time:
            return 0
        else:
            index = self._counts.bisect_left(t)
            return self._counts.values()[index - 1]

    def num_pedestrians(self):
        """Returns the total number of pedestrians in the simulation"""
        return len(self.pedestrians)

    def get_pedestrians_in_interval(self, start, stop):
        """Returns a list of all the pedestrians who entered in the interval
        [start, stop]"""
        return list(self.pedestrians.irange_key(start, stop))

    def num_entrances_in_interval(self, start, stop):
        """Returns the number of pedestrians who entered in the given interval
        of time [start, stop]"""
        return len(self.get_pedestrians_in_interval(start, stop))
    
    def get_enter_and_exit_times_in_interval(self, start, stop):
        """Returns the entrance and exit times in a given time interval
        as a tuple of lists (entrance_times, exit_times)."""
        start_times = []
        end_times = []
        for p in self.pedestrians:
            if start <= p.start_time <= stop:
                start_times.append(p.start_time)
            if start <= p.end_time <= stop:
                end_times.append(p.end_time)
        return start_times, end_times
    
    def get_pedestrians_at_time(self, t):
        """Returns a list of all the pedestrians who were there at time t"""
        # get all pedestrians who entered at or before time t
        entered_before_t = self.pedestrians.irange_key(
            min_key=None, max_key=t, inclusive=(True, True))
        # Of those, return the ones who exited after time t
        return [p for p in entered_before_t if p.end_time > t]
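The `_counts` bookkeeping above is a step function stored in a SortedDict: keys are breakpoint times, values are the head count from that time until the next breakpoint. A stripped-down sketch of the same technique (interval add plus point query), using the same left-inclusive/right-exclusive convention:

from sortedcontainers import SortedDict

counts = SortedDict({0: 0})  # at time 0, nobody is present

def count_at(t):
    if t in counts:
        return counts[t]
    idx = counts.bisect_left(t)
    return counts.values()[idx - 1]  # value of the breakpoint just before t

def add_interval(start, end):
    # create breakpoints at start and end carrying the current count ...
    counts[start] = count_at(start)
    counts[end] = count_at(end)
    # ... then increment every breakpoint in [start, end)
    for t in counts.irange(start, end, inclusive=(True, False)):
        counts[t] += 1  # value updates during irange are safe; key inserts are not

add_interval(2, 5)
add_interval(3, 7)
print(count_at(4), count_at(5), count_at(6))  # 2 1 1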
Exemplo n.º 28
0
class Book(object):
    '''| TODO
    | Keep track of the active orders. It is constructed from a plain dict and two sorted maps.
    | The dict holds all active orders; in this implementation it is used to check whether an order already
    | exists in the book. The sorted maps represent the bid and ask depths of the book, using the price as a key. For
    | efficiency, the price is represented as a (scaled) int. The insert operation places each element at the correct
    | position, implementing price priority in the book. Each element of the maps is a price level (see Level).
    | Note that, because the bid map uses a negated key, the best bid is the first price level of the bid map, and the
    | best ask is the first price level of the ask map.
    |________'''
    def __init__(self):
        self.bid = SortedDict(
            neg)  # Key is price as int, value is Level (in descending order)
        self.ask = SortedDict()  # Key is price as int, value is Level

        # Uniqueness of the keys is guaranteed only for active orders, i.e., if an order is removed, another order with the same key can be added
        self.activeOrders = {
        }  # Unordered map of active orders; Key is order Id, value is Order. Used for quick search of orders.
        # Otherwise, we need to iterate over the levels (bid and ask) and to check the orders for the orderId in question

    def isBidEmpty(self):
        return len(self.bid) == 0

    def isAskEmpty(self):
        return len(self.ask) == 0

    def isEmpty(self):
        return len(self.activeOrders) == 0

    def isPresent(self, anOrderId):
        return anOrderId in self.activeOrders

    def addOrder(self,
                 isBuy=None,
                 orderId=None,
                 price=None,
                 qty=None,
                 peakSize=None,
                 order=None):
        '''| TODO 
        | Creates and adds an order to the map of orders. In addition, pointer to the order is added to the proper map (bid/ask) and
        | vector (price level). The maps are ordered, therefore, inserting elements with price as keys, automatically builds a correct
        | depth of the book.
        | Note: best bid is the last element (price level) of the bid map; best ask is the first element (price level) of the ask map.
        |________'''

        # Already checked that an order with the same Id is not present in the book
        myOrder = Order(isBuy, orderId, price, qty,
                        peakSize) if order is None else order

        self.activeOrders[myOrder.orderId] = myOrder

        # TODO: Where do we deal with int*100 price as keys?
        key_price = int(myOrder.price * 100)

        level = self.bid if myOrder.isBuy else self.ask

        if key_price not in level:
            level[key_price] = Level()

        level[key_price].addOrder(myOrder)

    def removeOrder(self, orderId):
        '''| TODO 
        | Removes an active order from the book if present (return false if order not found).
        | In case of icebergs, removes both the visible and hidden parts.
        |________'''
        if orderId in self.activeOrders:
            isBuy = self.activeOrders[orderId].isBuy
            key_price = int(self.activeOrders[orderId].price * 100)

            level = self.bid if isBuy else self.ask
            level[key_price].remove(orderId)

            del self.activeOrders[orderId]
            return True

        return False

    def removeActiveOrder(self, orderId):
        '''| TODO 
        |________'''
        if orderId in self.activeOrders:
            del self.activeOrders[orderId]
            return True

        return False

    def removeEmptyLevels(self):
        '''| TODO 
        | If an incoming order executes and matches with all active orders of the best level
        | including visible and invisible part of the orders, the level is considered empty
        | and the matching continues with the next price level. After the execution, before
        | adding an order and processing a new incoming order, this function is used to remove
        | all empty levels. The book state is updated with new best (bid/ask) levels.
        |________'''
        for price in list(self.bid.keys()):  # copy the keys: deleting during iteration is not allowed
            if self.bid[price].isEmpty():
                del self.bid[price]

        for price in list(self.ask.keys()):
            if self.ask[price].isEmpty():
                del self.ask[price]

    def clear(self):
        self.activeOrders.clear()
        self.bid.clear()
        self.ask.clear()

    def show(self):
        '''| TODO 
        | Called as a result of command 's'
        | Since the best price is listed first, and the maps used to store the levels are ordered,
        | this function outputs the bid levels by traversing the bid map in reverse (highest price first)
        |________'''

        if self.isEmpty():
            print("Book --- EMPTY ---")

        else:
            if len(self.bid) == 0:
                print("Bid depth --- EMPTY ---")
            else:
                print("Bid depth (highest priority at top):")
                print("Price     ", "Order Id  ", "Quantity  ", "Iceberg")

                # Highest price first
                for _, level in self.bid.items():
                    level.show()
                print()

            if len(self.ask) == 0:
                print("Ask depth --- EMPTY ---")
            else:
                print("Ask depth (highest priority at top):")
                print("Price     ", "Order Id  ", "Quantity  ", "Iceberg")

                # Lowest price first
                for _, level in self.ask.items():
                    level.show()
        print()
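Because the bid map uses a negated key, the best level on both sides sits at index 0. A tiny sketch of that layout, with hypothetical scaled-int prices:

from operator import neg
from sortedcontainers import SortedDict

bid = SortedDict(neg, {10050: 'level@100.50', 10040: 'level@100.40'})  # descending by price
ask = SortedDict({10060: 'level@100.60', 10070: 'level@100.70'})  # ascending by price

print(bid.peekitem(0))  # (10050, 'level@100.50') -- best bid = highest price
print(ask.peekitem(0))  # (10060, 'level@100.60') -- best ask = lowest price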
Exemplo n.º 29
0
class Controller(object):
    def __init__(self, thrift_server, thrift_port):
        self.transport = TSocket.TSocket(thrift_server, thrift_port)
        self.transport = TTransport.TBufferedTransport(self.transport)
        bprotocol = TBinaryProtocol.TBinaryProtocol(self.transport)
        conn_mgr_protocol = TMultiplexedProtocol.TMultiplexedProtocol(
            bprotocol, "conn_mgr")
        self.conn_mgr = conn_mgr_pd_rpc.conn_mgr.Client(conn_mgr_protocol)
        p4_protocol = TMultiplexedProtocol.TMultiplexedProtocol(
            bprotocol, "pegasus")
        self.pegasus = pegasus.p4_pd_rpc.pegasus.Client(p4_protocol)
        self.devport = devport_mgr_pd_rpc.devport_mgr.Client(conn_mgr_protocol)
        self.transport.open()

        self.sess_hdl = self.conn_mgr.client_init()
        self.dev = 0
        self.dev_tgt = DevTarget_t(self.dev, hex_to_i16(0xFFFF))
        self.flags = pegasus_register_flags_t(read_hw_sync=True)

        # keyhash -> ReplicatedKey (sorted in ascending load)
        self.replicated_keys = SortedDict(
            lambda x: self.replicated_keys[x].load)
        self.num_nodes = DEFAULT_NUM_NODES
        self.num_rkeys = MAX_NRKEYS
        self.switch_lock = threading.Lock()

    def install_table_entries(self, tables):
        # tab_l2_forward
        self.pegasus.tab_l2_forward_set_default_action__drop(
            self.sess_hdl, self.dev_tgt)
        for (mac, port) in tables["tab_l2_forward"].items():
            self.pegasus.tab_l2_forward_table_add_with_l2_forward(
                self.sess_hdl, self.dev_tgt,
                pegasus_tab_l2_forward_match_spec_t(
                    ethernet_dstAddr=macAddr_to_string(mac)),
                pegasus_l2_forward_action_spec_t(action_port=port))
        # tab_node_forward
        self.pegasus.tab_node_forward_set_default_action__drop(
            self.sess_hdl, self.dev_tgt)
        self.num_nodes = len(tables["tab_node_forward"])
        for (node, attrs) in tables["tab_node_forward"].items():
            self.pegasus.tab_node_forward_table_add_with_node_forward(
                self.sess_hdl, self.dev_tgt,
                pegasus_tab_node_forward_match_spec_t(meta_node=int(node)),
                pegasus_node_forward_action_spec_t(
                    action_mac_addr=macAddr_to_string(attrs["mac"]),
                    action_ip_addr=ipv4Addr_to_i32(attrs["ip"]),
                    action_udp_addr=attrs["udp"],
                    action_port=attrs["port"]))
        # reg_n_servers
        self.pegasus.register_write_reg_n_servers(self.sess_hdl, self.dev_tgt,
                                                  0, self.num_nodes)
        # tab_calc_rset_index
        for i in range(self.num_rkeys):
            self.pegasus.tab_calc_rset_index_table_add_with_calc_rset_index(
                self.sess_hdl, self.dev_tgt,
                pegasus_tab_calc_rset_index_match_spec_t(meta_rkey_index=i),
                pegasus_calc_rset_index_action_spec_t(action_base=i *
                                                      MAX_RSET_SIZE))
        # tab_replicated_keys
        self.all_rkeys = tables["tab_replicated_keys"]
        for i in range(self.num_rkeys):
            self.add_rkey(self.all_rkeys[i], i, 0)
        self.conn_mgr.complete_operations(self.sess_hdl)

    def reset(self):
        self.switch_lock.acquire()
        for keyhash in self.replicated_keys.keys():
            self.pegasus.tab_replicated_keys_table_delete_by_match_spec(
                self.sess_hdl, self.dev_tgt,
                pegasus_tab_replicated_keys_match_spec_t(
                    pegasus_keyhash=int(keyhash)))
        self.replicated_keys.clear()
        for i in range(self.num_rkeys):
            self.add_rkey(self.all_rkeys[i], i, 0)
        self.pegasus.register_write_reg_ver_next(self.sess_hdl, self.dev_tgt,
                                                 0, 1)
        self.pegasus.register_write_reg_n_servers(self.sess_hdl, self.dev_tgt,
                                                  0, self.num_nodes)
        self.pegasus.register_write_reg_rr_all_servers(self.sess_hdl,
                                                       self.dev_tgt, 0, 0)
        self.conn_mgr.complete_operations(self.sess_hdl)
        self.switch_lock.release()

    def add_rkey(self, keyhash, rkey_index, load):
        if RSET_ALL:
            self.pegasus.register_write_reg_rset_size(self.sess_hdl,
                                                      self.dev_tgt, rkey_index,
                                                      self.num_nodes)
            bitmap = (2**self.num_nodes) - 1
            self.pegasus.register_write_reg_rset_bitmap(
                self.sess_hdl, self.dev_tgt, rkey_index, bitmap)
            rset_index = rkey_index << RSET_INDEX_SHIFT
            for i in range(self.num_nodes):
                self.pegasus.register_write_reg_rset(self.sess_hdl,
                                                     self.dev_tgt,
                                                     rset_index + i, i)
        else:
            node = int(keyhash) % self.num_nodes
            bitmap = 1 << node
            rset_index = rkey_index << RSET_INDEX_SHIFT
            self.pegasus.register_write_reg_rset_size(self.sess_hdl,
                                                      self.dev_tgt, rkey_index,
                                                      1)
            self.pegasus.register_write_reg_rset_bitmap(
                self.sess_hdl, self.dev_tgt, rkey_index, bitmap)
            self.pegasus.register_write_reg_rset(self.sess_hdl, self.dev_tgt,
                                                 rset_index, node)
        self.pegasus.register_write_reg_rkey_ver_completed(
            self.sess_hdl, self.dev_tgt, rkey_index, 1)
        self.pegasus.register_write_reg_rkey_read_counter(
            self.sess_hdl, self.dev_tgt, rkey_index, 0)
        self.pegasus.register_write_reg_rkey_write_counter(
            self.sess_hdl, self.dev_tgt, rkey_index, 0)
        self.pegasus.register_write_reg_rkey_rate_counter(
            self.sess_hdl, self.dev_tgt, rkey_index, 0)
        self.pegasus.register_write_reg_rr_rkey(self.sess_hdl, self.dev_tgt,
                                                rkey_index, 0)
        self.pegasus.tab_replicated_keys_table_add_with_is_rkey(
            self.sess_hdl, self.dev_tgt,
            pegasus_tab_replicated_keys_match_spec_t(
                pegasus_keyhash=int(keyhash)),
            pegasus_is_rkey_action_spec_t(action_rkey_index=rkey_index))
        self.replicated_keys.setdefault(
            keyhash, ReplicatedKey(index=int(rkey_index), load=load))

    def periodic_update(self):
        self.switch_lock.acquire()
        # Reset read and write counters
        self.pegasus.register_reset_all_reg_rkey_read_counter(
            self.sess_hdl, self.dev_tgt)
        self.pegasus.register_reset_all_reg_rkey_write_counter(
            self.sess_hdl, self.dev_tgt)
        # Read rkey load
        for rkey in self.replicated_keys.values():
            read_value = self.pegasus.register_read_reg_rkey_rate_counter(
                self.sess_hdl, self.dev_tgt, rkey.index, self.flags)
            rkey.load = int(read_value[1])
        # Reset rkey load
        self.pegasus.register_reset_all_reg_rkey_rate_counter(
            self.sess_hdl, self.dev_tgt)
        self.conn_mgr.complete_operations(self.sess_hdl)
        self.switch_lock.release()

    def print_stats(self):
        self.switch_lock.acquire()
        # read replicated keys info
        for (keyhash, rkey) in self.replicated_keys.items():
            read_value = self.pegasus.register_read_reg_rkey_rate_counter(
                self.sess_hdl, self.dev_tgt, rkey.index, self.flags)
            rkey.load = read_value[0]
            print("rkey hash", keyhash)
            print("rkey load", rkey.load)
            read_value = self.pegasus.register_read_reg_rkey_ver_completed(
                self.sess_hdl, self.dev_tgt, rkey.index, self.flags)
            print("ver completed", read_value[0])
            read_value = self.pegasus.register_read_reg_rset_size(
                self.sess_hdl, self.dev_tgt, rkey.index, self.flags)
            rset_size = int(read_value[0])
            print("rset size", rset_size)
            read_value = self.pegasus.register_read_reg_rset_bitmap(
                self.sess_hdl, self.dev_tgt, rkey.index, self.flags)
            print("rset bitmap", read_value[0])
            base = rkey.index * MAX_RSET_SIZE
            for i in range(rset_size):
                read_value = self.pegasus.register_read_reg_rset(
                    self.sess_hdl, self.dev_tgt, base + i, self.flags)
                print("replica", read_value[0])
        self.conn_mgr.complete_operations(self.sess_hdl)
        self.switch_lock.release()

    def run(self):
        while True:
            time.sleep(0.1)
            self.periodic_update()

    def stop(self):
        self.transport.close()
Exemplo n.º 30
0
class OrderBookManager(Initializable):
    def __init__(self):
        super().__init__()
        self.logger = get_logger(self.__class__.__name__)
        self.order_book_initialized = False
        self.asks = SortedDict()
        self.bids = SortedDict()
        self.timestamp = 0
        self.ask_quantity, self.ask_price, self.bid_quantity, self.bid_price = 0, 0, 0, 0

    async def initialize_impl(self):
        self.reset_order_book()

    def reset_order_book(self):
        self.order_book_initialized = False
        self.asks.clear()
        self.bids.clear()
        self.timestamp = 0
        self.ask_quantity, self.ask_price, self.bid_quantity, self.bid_price = 0, 0, 0, 0

    def order_book_ticker_update(self, ask_quantity, ask_price, bid_quantity,
                                 bid_price):
        self.ask_quantity, self.ask_price = ask_quantity, ask_price
        self.bid_quantity, self.bid_price = bid_quantity, bid_price

    def handle_new_book(self, orders):
        try:
            self.handle_new_books(asks=orders[ECOBIC.ASKS.value],
                                  bids=orders[ECOBIC.BIDS.value],
                                  timestamp=orders[ECOBIC.TIMESTAMP.value])
        except KeyError:
            self.logger.error("Failed to parse new order book")

    def handle_new_books(self, asks, bids, timestamp=None):
        self.reset_order_book()
        self.handle_book_adds(
            _convert_price_size_list_to_order(asks, TradeOrderSide.SELL.value))
        self.handle_book_adds(
            _convert_price_size_list_to_order(bids, TradeOrderSide.BUY.value))
        if timestamp:
            self.timestamp = timestamp
        self.order_book_initialized = True

    def handle_book_adds(self, orders):
        for order in orders:
            try:
                self._handle_book_add(order)
            except KeyError as e:
                self.logger.error(
                    f"Error when adding order to order_book : {e}")

    def handle_book_deletes(self, orders):
        for order in orders:
            try:
                self._handle_book_delete(order)
            except KeyError as e:
                self.logger.error(
                    f"Error when deleting order from order_book : {e}")

    def handle_book_updates(self, orders):
        for order in orders:
            try:
                self._handle_book_update(order)
            except KeyError as e:
                self.logger.error(
                    f"Error when updating order in order_book : {e}")

    def _handle_book_add(self, order):
        # Add buy side orders
        if order[ECOBIC.SIDE.value] == TradeOrderSide.BUY.value:
            bids = self.get_bids(order[ECOBIC.PRICE.value])
            if bids is None:
                bids = [order]
            else:
                bids.append(order)
            self._set_bids(order[ECOBIC.PRICE.value], bids)
            return

        # Add sell side orders
        asks = self.get_asks(order[ECOBIC.PRICE.value])
        if asks is None:
            asks = [order]
        else:
            asks.append(order)
        self._set_asks(order[ECOBIC.PRICE.value], asks)

    def _handle_book_delete(self, order):
        price = Decimal(order[ECOBIC.PRICE.value])

        # Delete buy side orders
        if order[ECOBIC.SIDE.value] == TradeOrderSide.BUY.value:
            bids = self.get_bids(price)
            if bids is not None:
                bids = [
                    bid_order for bid_order in bids if bid_order[
                        ECOBIC.ORDER_ID.value] != order[ECOBIC.ORDER_ID.value]
                ]
                if len(bids) > 0:
                    self._set_bids(price, bids)
                else:
                    self._remove_bids(price)
            return

        # Delete sell side orders
        asks = self.get_asks(price)
        if asks is not None:
            asks = [
                ask_order for ask_order in asks if ask_order[
                    ECOBIC.ORDER_ID.value] != order[ECOBIC.ORDER_ID.value]
            ]
            if len(asks) > 0:
                self._set_asks(price, asks)
            else:
                self._remove_asks(price)

    def _handle_book_update(self, order):
        size = Decimal(order.get(ECOBIC.SIZE.value, INVALID_PARSED_VALUE))
        price = Decimal(order[ECOBIC.PRICE.value])

        # Update buy side orders
        if order[ECOBIC.SIDE.value] == TradeOrderSide.BUY.value:
            bids = self.get_bids(price)
            order_index = _order_id_index(order[ECOBIC.ORDER_ID.value], bids)
            if bids is None or order_index == ORDER_ID_NOT_FOUND:
                return
            if size != INVALID_PARSED_VALUE:
                bids[order_index][ECOBIC.SIZE.value] = size
            self._set_bids(price, bids)
            return

        # Update sell side orders
        asks = self.get_asks(price)
        order_index = _order_id_index(order[ECOBIC.ORDER_ID.value], asks)
        if asks is None or order_index == ORDER_ID_NOT_FOUND:
            return
        if size != INVALID_PARSED_VALUE:
            asks[order_index][ECOBIC.SIZE.value] = size
        self._set_asks(price, asks)

    def _set_asks(self, price, asks):
        self.asks[price] = asks

    def _set_bids(self, price, bids):
        self.bids[price] = bids

    def _remove_asks(self, price):
        del self.asks[price]

    def _remove_bids(self, price):
        del self.bids[price]

    def get_ask(self):
        return self.asks.peekitem(0)

    def get_bid(self):
        return self.bids.peekitem(-1)

    def get_asks(self, price):
        return self.asks.get(price, None)

    def get_bids(self, price):
        return self.bids.get(price, None)
Exemplo n.º 31
0
class RL():
    def __init__(self, sess, env, n_s, n_a, args):

        if sess is None:
            self.sess = tf.Session()
        else:
            self.sess = sess

        self.args = args
        self.env = env
        self.env.seed(self.args.seed)

        self.n_s = n_s
        self.n_a = n_a

        self.init = True  ### whether we are still in the init phase; if so, use random actions

        self.ite_count = 0  ### number of training iterations

        self.dict = SortedDict()  ### episodes keyed by total reward (ascending)

        self.release = 10
        self.reward_ = 200  ### used to clip the desired reward

        self.save_index = 0  ### checkpoint save index

        self.network_model()
        self.saver = tf.compat.v1.train.Saver()
        tf.compat.v1.random.set_random_seed(args.seed)

    def network_model(self):
        def init_weights(input_):
            x = 1 / (np.sqrt(input_))
            return tf.compat.v1.random_uniform_initializer(-x,
                                                           x,
                                                           seed=self.args.seed)

        def behavior_build_network():
            w_init = tf.compat.v1.initializers.variance_scaling(
                scale=np.sqrt(2 / (1 + np.sqrt(5)**2)),
                distribution='uniform',
                seed=self.args.seed)

            l1 = tf.layers.dense(
                inputs=self.input_,
                units=self.args.hidden_units,
                kernel_initializer=w_init,
                bias_initializer=init_weights(self.n_s),
                activation='sigmoid',
                name='l1',
            )
            c1 = tf.layers.dense(inputs=tf.concat((self.d_r, self.d_h), 1),
                                 units=self.args.hidden_units,
                                 kernel_initializer=w_init,
                                 bias_initializer=init_weights(
                                     self.args.hidden_units),
                                 activation='sigmoid',
                                 name='c_out')

            out_1 = tf.math.multiply(l1, c1)

            l2 = tf.layers.dense(
                inputs=out_1,
                units=self.n_a,
                activation=None,
                kernel_initializer=w_init,
                bias_initializer=init_weights(self.args.hidden_units),
                name='l2',
            )

            # b=tf.layers.dense(
            # 	inputs=l2,
            # 	units=self.n_a,
            # 	kernel_initializer=w_init,
            # 	bias_initializer=init_weights(self.args.hidden_units),
            # 	activation=None,
            # 	name='out'
            # 	)
            b = l2
            return b

        ### ALL input placeholders
        self.input_ = tf.compat.v1.placeholder(tf.float32, [None, self.n_s],
                                               'input_')

        self.c_in = tf.compat.v1.placeholder(tf.float32, [None, 2], 'c_in')
        self.d_h = tf.compat.v1.placeholder(tf.float32, [None, 1], 'd_h')
        self.d_r = tf.compat.v1.placeholder(tf.float32, [None, 1], 'd_r')

        self.a = tf.compat.v1.placeholder(tf.int32, [
            None,
        ], 'action')

        with tf.compat.v1.variable_scope('behavior_function'):
            self.b = behavior_build_network()
            self.b_softmax = tf.nn.softmax(self.b)
            self.a_out = tf.squeeze(
                tf.random.categorical(logits=self.b,
                                      num_samples=1,
                                      seed=self.args.seed))

        with tf.compat.v1.variable_scope('loss'):
            self.loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.b,
                                                               labels=self.a))

        with tf.compat.v1.variable_scope('train'):
            self.train_op = tf.compat.v1.train.AdamOptimizer(
                self.args.lr).minimize(self.loss)

    def action_choice(self, s, c, dr, dh):
        s = np.asarray(s, dtype=np.float32).reshape((1, self.n_s))
        dr = np.asarray(dr).reshape((-1, 1))
        dh = np.asarray(dh).reshape((-1, 1))
        action = self.sess.run(self.a_out, {
            self.input_: s,
            self.d_r: dr,
            self.d_h: dh
        })
        return action

    def get_(self):
        if self.init:
            self.desire_r_init, self.desire_h_init = 0, 0
            return
        h = []
        r = []

        for _ in range(self.args.generate_per_single_training):
            episode = self.dict.popitem()  # pops the highest-reward episode
            h.append(len(episode[1][0]))
            r.append(episode[0])

        selected_episode_len = np.mean(h)
        selected_episode_mean = np.random.uniform(low=np.mean(r),
                                                  high=(np.mean(r) +
                                                        np.std(r)))
        self.desire_r_init, self.desire_h_init = selected_episode_mean, selected_episode_len

    def feed(self):
        self.get_()
        self.dict.clear()
        for _ in range(self.args.memory_thersold):
            state, action, reward, total_reward = self.play()
            self.dict[total_reward] = (state, action, reward)  # episodes with equal total_reward overwrite each other
        self.init = False

    def play(self):
        s = self.env.reset()
        if self.ite_count == 0:
            self.sess.run(tf.compat.v1.global_variables_initializer())

        state_list = []
        action_list = []
        reward_list = []

        reward_total = 0
        done = False

        desire_h = self.desire_h_init
        desire_r = self.desire_r_init

        while not done:
            c = np.asarray([desire_h, desire_r])

            if self.init:
                a = np.random.randint(self.n_a)
            else:
                a = self.action_choice(s, c, desire_r, desire_h)

            s_, r, done, _ = self.env.step(a)

            state_list.append(s)
            action_list.append(a)
            reward_list.append(r)
            reward_total += r

            desire_h = max(desire_h - 1, 1)
            desire_r = min(desire_r - r, self.reward_)

            s = s_

            if done:
                break
        return state_list, action_list, reward_list, reward_total

    def learn(self):
        if self.ite_count == 0:
            self.sess.run(tf.compat.v1.global_variables_initializer())

        memory_dic = dict(self.dict)
        dic_value = list(memory_dic.values())

        for _ in range(self.args.n_update_eposide):
            state = []
            dr = []
            dh = []
            true_a = []
            c = []
            indices = np.random.choice(
                len(dic_value), self.args.batch_size,
                replace=True)  # randomly sample which episodes to use
            tran = [dic_value[i] for i in indices]
            random_index = [np.random.choice(len(e[0]) - 2, 1) for e in tran]
            for idx_, tran_ in zip(random_index, tran):
                state.append(tran_[0][idx_[0]])
                dr.append(np.sum(tran_[2][idx_[0]:]))
                dh.append(len(tran_[0]) - idx_[0])
                true_a.append(tran_[1][idx_[0]])
                c.append([np.sum(tran_[2][idx_[0]:]), len(tran_[0]) - idx_[0]])

            command_ = np.asarray(c, dtype=np.float32).reshape(-1, 2)
            s_t = np.asarray(state, dtype=np.float32)
            action = np.asarray([a_ for a_ in true_a])
            dr = np.asarray(dr, dtype=np.float32).reshape((-1, 1))
            dh = np.asarray(dh, dtype=np.float32).reshape((-1, 1))
            _, loss = self.sess.run(
                [self.train_op, self.loss], {
                    self.input_: s_t,
                    self.c_in: command_,
                    self.a: action,
                    self.d_r: dr,
                    self.d_h: dh
                })

    def eval(self, eval_ite):
        test_reward = []
        test_step = []
        for i in range(self.args.eval_step):
            _, _, r_list, total_reward = self.play()
            test_reward.append(total_reward)
            test_step.append(len(r_list))
        print('ite: {},   reward: {:.3f},'.format(eval_ite,
                                                  np.mean(test_reward)))
        return np.mean(test_reward)

    def train(self):
        self.feed()
        test = []
        start_ = time.time()  # timer for the elapsed-minutes print after each checkpoint save
        print(
            '----------------using tensorflow with {} generate_step_per_single_training----------------'
            .format(self.args.generate_per_single_training))
        while True:
            self.learn()
            self.ite_count += 1
            self.feed()
            if (self.ite_count - 1) % self.args.eval_step_every_k_step == 0:
                score = self.eval(self.ite_count - 1)
                test.append(score)
                if len(test) % self.release == 0 or (self.ite_count - 1) == 0:
                    self.saver.save(
                        self.sess,
                        os.path.join(
                            self.args.save_path,
                            'tensorflow_model_{}_tensorflow_categorical_1.ckpt'
                            .format(self.args.generate_per_single_training)))
                    print('saved')
                    self.save_index += 1
                    test = []
                    print((time.time() - start_) / 60)  # minutes since the previous save
                    start_ = time.time()
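
The command update inside play() is the heart of the rollout: after every environment step the desired horizon shrinks by one (floored at 1) and the desired return is reduced by the reward just obtained (capped at the environment's maximum reward, self.reward_ above). A minimal self-contained sketch of that rule, with a hypothetical reward_cap standing in for self.reward_:

def update_command(desire_r, desire_h, reward, reward_cap):
    """Advance the (desired return, desired horizon) command by one step."""
    desire_h = max(desire_h - 1, 1)                # one step fewer remaining, never below 1
    desire_r = min(desire_r - reward, reward_cap)  # return still to be achieved, clipped
    return desire_r, desire_h

# Toy trace: demand 200 reward over 100 steps, then observe a reward of 5.
print(update_command(200.0, 100, 5.0, reward_cap=250.0))  # -> (195.0, 99)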
Exemplo n.º 32
0
class LeafSet(object):
    __slots__ = ('peers', 'capacity')
    __passthru = {'get', 'clear', 'pop', 'popitem', 'peekitem', 'key'}
    __iters = {'keys', 'values', 'items'}

    def __init__(self, my_key, iterable=(), capacity=8):
        try:
            iterable = iterable.items()  # view object
        except AttributeError:
            pass
        tuple_itemgetter = Peer.distance(my_key, itemgetter(0))
        key_itemgetter = Peer.distance(my_key)
        self.capacity = capacity
        self.peers = SortedDict(key_itemgetter)
        if iterable:
            closest = sorted(iterable, key=tuple_itemgetter)
            self.peers.update(islice(closest, capacity))

    def clear(self):
        self.peers.clear()

    def prune(self):
        extra = len(self) - self.capacity
        for i in range(extra):
            self.peers.popitem(last=True)

    def update(self, iterable):
        try:
            iterable = iterable.items()  # view object
        except AttributeError:
            pass
        iterable = iter(iterable)
        items = tuple(islice(iterable, 500))
        while items:
            self.peers.update(items)
            items = tuple(islice(iterable, 500))
        self.prune()  # restore the capacity invariant after a bulk update

    def setdefault(self, *args, **kwargs):
        self.peers.setdefault(*args, **kwargs)
        self.prune()

    def __setitem__(self, *args, **kwargs):
        self.peers.__setitem__(*args, **kwargs)
        self.prune()

    def __getitem__(self, *args, **kwargs):
        return self.peers.__getitem__(*args, **kwargs)

    def __delitem__(self, *args, **kwargs):
        return self.peers.__delitem__(*args, **kwargs)

    def __iter__(self, *args, **kwargs):
        return self.peers.__iter__(*args, **kwargs)

    def __reversed__(self, *args, **kwargs):
        return self.peers.__reversed__(*args, **kwargs)

    def __contains__(self, *args, **kwargs):
        return self.peers.__contains__(*args, **kwargs)

    def __len__(self, *args, **kwargs):
        return self.peers.__len__(*args, **kwargs)

    def __getattr__(self, key):
        if key in self.__class__.__passthru:
            return getattr(self.peers, key)
        elif key in self.__class__.__iters:
            return getattr(self.peers, 'iter' + key)
        else:
            # object defines no __getattr__, so delegating to super() would itself fail;
            # raise the expected AttributeError directly.
            raise AttributeError(key)

    def __repr__(self):
        return '<%s keys=%r capacity=%d/%d>' % (
            self.__class__.__name__, list(self), len(self), self.capacity)
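
A rough usage sketch under a stated assumption: Peer.distance is defined elsewhere in the codebase, so the stub below merely stands in for it. distance(my_key) must return a one-argument key function, and distance(my_key, itemgetter(0)) the same function composed over (key, value) tuples; the SortedDict then orders peers by closeness and prune() evicts the farthest.

from operator import itemgetter

class Peer:
    @staticmethod
    def distance(my_key, getter=None):
        # Stand-in for the real Peer.distance: plain absolute numeric distance.
        if getter is None:
            return lambda key: abs(key - my_key)
        return lambda item: abs(getter(item) - my_key)

ls = LeafSet(my_key=100, capacity=3)
for key in (90, 250, 104, 99, 180):
    ls[key] = 'peer-%d' % key  # __setitem__ prunes back down to capacity
print(list(ls))                # the three keys closest to 100: [99, 104, 90]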
Exemplo n.º 33
0
class OrderedDict(dict):
    """Dictionary that remembers insertion order and is numerically indexable.

    Keys are numerically indexable using dict views. For example::

        >>> ordered_dict = OrderedDict.fromkeys('abcde')
        >>> keys = ordered_dict.keys()
        >>> keys[0]
        'a'
        >>> keys[-2:]
        ['d', 'e']

    The dict views support the sequence abstract base class.

    """

    # pylint: disable=super-init-not-called
    def __init__(self, *args, **kwargs):
        self._keys = {}
        self._nums = SortedDict()
        self._keys_view = self._nums.keys()
        self._count = count()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
        "``ordered_dict[key] = value``"
        if key not in self:
            num = next(self._count)
            self._keys[key] = num
            self._nums[num] = key
        dict_setitem(self, key, value)

    def __delitem__(self, key, dict_delitem=dict.__delitem__):
        "``del ordered_dict[key]``"
        dict_delitem(self, key)
        num = self._keys.pop(key)
        del self._nums[num]

    def __iter__(self):
        "``iter(ordered_dict)``"
        return iter(self._nums.values())

    def __reversed__(self):
        "``reversed(ordered_dict)``"
        nums = self._nums
        for key in reversed(nums):
            yield nums[key]

    def clear(self, dict_clear=dict.clear):
        "Remove all items from mapping."
        dict_clear(self)
        self._keys.clear()
        self._nums.clear()

    def popitem(self, last=True):
        """Remove and return (key, value) item pair.

        Pairs are returned in LIFO order if last is True or FIFO order if
        False.

        """
        index = -1 if last else 0
        num = self._keys_view[index]
        key = self._nums[num]
        value = self.pop(key)
        return key, value

    update = __update = co.MutableMapping.update

    def keys(self):
        "Return set-like and sequence-like view of mapping keys."
        return KeysView(self)

    def items(self):
        "Return set-like and sequence-like view of mapping items."
        return ItemsView(self)

    def values(self):
        "Return set-like and sequence-like view of mapping values."
        return ValuesView(self)

    def pop(self, key, default=NONE):
        """Remove given key and return corresponding value.

        If key is not found, default is returned if given, otherwise raise
        KeyError.

        """
        if key in self:
            value = self[key]
            del self[key]
            return value
        elif default is NONE:
            raise KeyError(key)
        else:
            return default

    def setdefault(self, key, default=None):
        """Return ``mapping.get(key, default)``, also set ``mapping[key] = default`` if
        key not in mapping.

        """
        if key in self:
            return self[key]
        self[key] = default
        return default

    @recursive_repr()
    def __repr__(self):
        "Text representation of mapping."
        return '%s(%r)' % (self.__class__.__name__, list(self.items()))

    __str__ = __repr__

    def __reduce__(self):
        "Support for pickling serialization."
        return (self.__class__, (list(self.items()), ))

    def copy(self):
        "Return shallow copy of mapping."
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Return new mapping with keys from iterable.

        If not specified, value defaults to None.

        """
        return cls((key, value) for key in iterable)

    def __eq__(self, other):
        "Test self and other mapping for equality."
        if isinstance(other, OrderedDict):
            return dict.__eq__(self, other) and all(map(eq, self, other))
        return dict.__eq__(self, other)

    __ne__ = co.MutableMapping.__ne__

    def _check(self):
        "Check consistency of internal member variables."
        # pylint: disable=protected-access
        keys = self._keys
        nums = self._nums

        for key, value in keys.items():
            assert nums[value] == key

        nums._check()
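
A short usage sketch of the numerically indexable views promised by the docstring (it mirrors the doctest above, so nothing here is new API):

od = OrderedDict.fromkeys('abcde')
keys = od.keys()
print(keys[0], keys[-2:])      # 'a' ['d', 'e']
print(od.popitem(last=False))  # ('a', None) -- FIFO pop from the front
od['f'] = 1
print(list(od))                # ['b', 'c', 'd', 'e', 'f'] -- insertion order preserved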
Exemplo n.º 34
0
class UpsideDownRL(object):
    def __init__(self, env, args):
        super(UpsideDownRL, self).__init__()
        self.env = env
        self.args = args
        self.nb_actions = self.env.action_space.n
        self.state_space = self.env.observation_space.shape[0]

        # Use sorted dict to store experiences gathered.
        # This helps in fetching highest reward trajectories during exploratory stage.
        self.experience = SortedDict()
        self.B = BehaviorFunc(self.state_space, self.nb_actions, args).cuda()
        self.optimizer = optim.Adam(self.B.parameters(), lr=self.args.lr)
        self.use_random_actions = True  # True for the first training epoch.
        self.softmax = nn.Softmax(dim=-1)  # explicit dim avoids the implicit-dimension warning
        # Used to clip rewards so that B does not get unrealistic expected reward inputs.
        self.lunar_lander_max_reward = 250

    # Generate an episode using given command inputs to the B function.
    def gen_episode(self, dr, dh):
        state = self.env.reset()
        episode_data = []
        states = []
        rewards = []
        actions = []
        total_reward = 0
        while True:
            action = self.select_action(state, dr, dh)
            next_state, reward, is_terminal, _ = self.env.step(action)
            if self.args.render:
                self.env.render()
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            total_reward += reward
            state = next_state
            dr = min(dr - reward, self.lunar_lander_max_reward)
            dh = max(dh - 1, 1)
            if is_terminal:
                break

        return total_reward, states, actions, rewards

    # Fetch the desired return and horizon from the best trajectories in the current replay buffer
    # to sample more trajectories using the latest behavior function.
    def fill_replay_buffer(self):
        dr, dh = self.get_desired_return_and_horizon()
        self.experience.clear()
        for i in range(self.args.replay_buffer_capacity):
            total_reward, states, actions, rewards = self.gen_episode(dr, dh)
            self.experience[total_reward] = (states, actions, rewards)

        if self.args.verbose:
            if self.use_random_actions:
                print("Filled replay buffer with random actions")
            else:
                print("Filled replay buffer using BehaviorFunc")
        self.use_random_actions = False

    def select_action(self, state, desired_return=None, desired_horizon=None):
        if self.use_random_actions:
            action = np.random.randint(self.nb_actions)
        else:
            action_prob = self.B(
                torch.from_numpy(state).cuda(),
                torch.from_numpy(np.array(desired_return,
                                          dtype=np.float32)).reshape(-1,
                                                                     1).cuda(),
                torch.from_numpy(
                    np.array(desired_horizon,
                             dtype=np.float32).reshape(-1, 1)).cuda())
            action_prob = self.softmax(action_prob)
            # create a categorical distribution over action probabilities
            dist = Categorical(action_prob)
            action = dist.sample().item()
        return action

    # TODO: don't popitem from the experience buffer, since these best-performing trajectories can have a huge impact on the learning of B
    def get_desired_return_and_horizon(self):
        if self.use_random_actions:
            return 0, 0

        h = []
        r = []
        for i in range(self.args.explore_buffer_len):
            episode = self.experience.popitem()  # will return in sorted order
            h.append(len(episode[1][0]))
            r.append(episode[0])

        mean_horizon_len = np.mean(h)
        mean_reward = np.random.uniform(low=np.mean(r),
                                        high=np.mean(r) + np.std(r))
        return mean_reward, mean_horizon_len

    def trainBehaviorFunc(self):
        experience_dict = dict(self.experience)
        experience_values = list(experience_dict.values())
        for i in range(self.args.train_iter):
            state = []
            dr = []
            dh = []
            target = []
            indices = np.random.choice(len(experience_values),
                                       self.args.batch_size,
                                       replace=True)
            train_episodes = [experience_values[i] for i in indices]
            timesteps = [np.random.choice(len(e[0]) - 2, 1) for e in train_episodes]

            for t, (ep_states, ep_actions, ep_rewards) in zip(timesteps, train_episodes):
                state.append(ep_states[t[0]])
                dr.append(np.sum(ep_rewards[t[0]:]))
                dh.append(len(ep_states) - t[0])
                target.append(ep_actions[t[0]])

            self.optimizer.zero_grad()
            state = torch.from_numpy(np.array(state)).cuda()
            dr = torch.from_numpy(
                np.array(dr, dtype=np.float32).reshape(-1, 1)).cuda()
            dh = torch.from_numpy(
                np.array(dh, dtype=np.float32).reshape(-1, 1)).cuda()
            target = torch.from_numpy(np.array(target)).long().cuda()
            action_logits = self.B(state, dr, dh)
            criterion = nn.CrossEntropyLoss()
            output = criterion(action_logits, target)
            output.backward()
            self.optimizer.step()

    # Evaluate the agent using the initial command input from the best topK performing trajectories.
    def evaluate(self):
        testing_rewards = []
        testing_steps = []
        dr, dh = self.get_desired_return_and_horizon()
        for i in range(self.args.evaluate_trials):
            total_reward, states, actions, rewards = self.gen_episode(dr, dh)
            testing_rewards.append(total_reward)
            testing_steps.append(len(rewards))

        print("Mean reward achieved : {}".format(np.mean(testing_rewards)))
        return np.mean(testing_rewards)

    def train(self):
        # Fill replay buffer with random actions for the first time.
        self.fill_replay_buffer()
        iterations = 0
        test_returns = []
        while True:
            # Train behavior function with trajectories stored in the replay buffer.
            self.trainBehaviorFunc()
            self.fill_replay_buffer()

            if iterations % self.args.eval_every_k_epoch == 0:
                test_returns.append(self.evaluate())
                torch.save(self.B.state_dict(),
                           os.path.join(self.args.save_path, "model.pkl"))
                np.save(os.path.join(self.args.save_path, "testing_rewards"),
                        test_returns)
            iterations += 1
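
The sampling step inside trainBehaviorFunc can be stated compactly: for each sampled episode pick a random timestep t, then train B to reproduce the action actually taken at t given (state_t, desired return = reward-to-go from t, desired horizon = steps remaining). A minimal NumPy-only sketch of that target construction on a toy four-step episode:

import numpy as np

states = [0.0, 1.0, 2.0, 3.0]
actions = [1, 0, 1, 0]
rewards = [0.5, -0.2, 1.0, 0.3]

t = np.random.choice(len(states) - 2)  # same offset rule as above: avoid the final steps
desired_return = np.sum(rewards[t:])   # reward-to-go from t
desired_horizon = len(states) - t      # steps remaining from t
target_action = actions[t]
print(t, desired_return, desired_horizon, target_action)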
Exemplo n.º 35
0
class NodeRegHandler(BatchRequestHandler, WriteRequestHandler):
    def __init__(self, database_manager: DatabaseManager):
        BatchRequestHandler.__init__(self, database_manager, POOL_LEDGER_ID)
        WriteRequestHandler.__init__(self, database_manager, NODE,
                                     POOL_LEDGER_ID)

        self.uncommitted_node_reg = []
        self.committed_node_reg = []
        self.node_reg_at_beginning_of_view = SortedDict()

        self._uncommitted = deque()  # type: deque[UncommittedNodeReg]
        self._uncommitted_view_no = 0
        self._committed_view_no = 0

    def on_catchup_finished(self):
        self._load_current_node_reg()
        # we must have node regs for at least the last two views
        self._load_last_view_node_reg()
        logger.info("Loaded current node registry from the ledger: {}".format(
            self.uncommitted_node_reg))
        logger.info("Current node registry for previous views: {}".format(
            sorted(self.node_reg_at_beginning_of_view.items())))

    def post_batch_applied(self,
                           three_pc_batch: ThreePcBatch,
                           prev_handler_result=None):
        # Observer case:
        if not self.uncommitted_node_reg and three_pc_batch.node_reg:
            self.uncommitted_node_reg = list(three_pc_batch.node_reg)

        view_no = three_pc_batch.view_no if three_pc_batch.original_view_no is None else three_pc_batch.original_view_no
        self._uncommitted.append(
            UncommittedNodeReg(list(self.uncommitted_node_reg), view_no))

        if view_no > self._uncommitted_view_no:
            self.node_reg_at_beginning_of_view[view_no] = list(
                self.uncommitted_node_reg)
            self._uncommitted_view_no = view_no

        three_pc_batch.node_reg = list(self.uncommitted_node_reg)

        logger.debug("Applied uncommitted node registry: {}".format(
            self.uncommitted_node_reg))
        logger.debug("Current node registry for previous views: {}".format(
            sorted(self.node_reg_at_beginning_of_view.items())))

    def post_batch_rejected(self, ledger_id, prev_handler_result=None):
        reverted = self._uncommitted.pop()
        if len(self._uncommitted) == 0:
            self.uncommitted_node_reg = self.committed_node_reg
            self._uncommitted_view_no = self._committed_view_no
        else:
            last_uncommitted = self._uncommitted[-1]
            self.uncommitted_node_reg = last_uncommitted.uncommitted_node_reg
            self._uncommitted_view_no = last_uncommitted.view_no
        if self._uncommitted_view_no < reverted.view_no:
            self.node_reg_at_beginning_of_view.pop(reverted.view_no)

        logger.debug("Reverted uncommitted node registry from {} to {}".format(
            reverted.uncommitted_node_reg, self.uncommitted_node_reg))
        logger.debug("Current node registry for previous views: {}".format(
            sorted(self.node_reg_at_beginning_of_view.items())))

    def commit_batch(self,
                     three_pc_batch: ThreePcBatch,
                     prev_handler_result=None):
        prev_committed = self.committed_node_reg
        self.committed_node_reg = self._uncommitted.popleft(
        ).uncommitted_node_reg
        self._committed_view_no = three_pc_batch.view_no if three_pc_batch.original_view_no is None else three_pc_batch.original_view_no

        # Make sure that we keep the node reg for the current view and for the previous one
        # (the previous view number may be more than one less than the current).
        # Ex.: node_reg_at_beginning_of_view has views {0, 3, 5, 7, 11, 13} and committed is now 7, so we keep all
        # uncommitted views (11, 13) plus the previous view (5). Views 0 and 3 need to be deleted.
        view_nos = list(self.node_reg_at_beginning_of_view.keys())
        prev_committed_index = max(view_nos.index(self._committed_view_no) - 1, 0) \
            if self._committed_view_no in self.node_reg_at_beginning_of_view else 0
        for view_no in view_nos[:prev_committed_index]:
            self.node_reg_at_beginning_of_view.pop(view_no, None)

        if prev_committed != self.committed_node_reg:
            logger.info("Committed node registry: {}".format(
                self.committed_node_reg))
            logger.info("Current node registry for previous views: {}".format(
                sorted(self.node_reg_at_beginning_of_view.items())))
        else:
            logger.debug("Committed node registry: {}".format(
                self.committed_node_reg))
            logger.debug("Current node registry for previous views: {}".format(
                sorted(self.node_reg_at_beginning_of_view.items())))

    def apply_request(self, request: Request, batch_ts, prev_result):
        if request.operation.get(TYPE) != NODE:
            return None, None, None

        node_name = request.operation[DATA][ALIAS]
        services = request.operation[DATA].get(SERVICES)

        if services is None:
            return None, None, None

        if node_name not in self.uncommitted_node_reg and VALIDATOR in services:
            # new node added or old one promoted
            self.uncommitted_node_reg.append(node_name)
            logger.info("Changed uncommitted node registry to: {}".format(
                self.uncommitted_node_reg))
        elif node_name in self.uncommitted_node_reg and VALIDATOR not in services:
            # existing node demoted
            self.uncommitted_node_reg.remove(node_name)
            logger.info("Changed uncommitted node registry to: {}".format(
                self.uncommitted_node_reg))

        return None, None, None

    def update_state(self, txn, prev_result, request, is_committed=False):
        pass

    def static_validation(self, request):
        pass

    def dynamic_validation(self, request):
        pass

    def gen_state_key(self, txn):
        pass

    def _load_current_node_reg(self):
        node_reg = self.__load_current_node_reg_from_audit_ledger()
        if node_reg is None:
            node_reg = self.__load_node_reg_from_pool_ledger()
        self.uncommitted_node_reg = list(node_reg)
        self.committed_node_reg = list(node_reg)

    def _load_last_view_node_reg(self):
        self.node_reg_at_beginning_of_view.clear()

        # 1. Check if we have an audit ledger at all
        audit_ledger = self.database_manager.get_ledger(AUDIT_LEDGER_ID)
        if not audit_ledger:
            # no audit ledger yet, so use the already loaded values from the pool ledger
            self.node_reg_at_beginning_of_view[0] = list(
                self.uncommitted_node_reg)
            self._uncommitted_view_no = 0
            self._committed_view_no = 0
            return

        # 2. get the first txn in the current view
        first_txn_in_this_view, last_txn_in_prev_view = self.__get_first_txn_in_view_from_audit(
            audit_ledger, audit_ledger.get_last_committed_txn())
        self._uncommitted_view_no = get_payload_data(
            first_txn_in_this_view)[AUDIT_TXN_VIEW_NO]
        self._committed_view_no = self._uncommitted_view_no
        self.node_reg_at_beginning_of_view[self._committed_view_no] = list(
            self.__load_node_reg_for_view(audit_ledger,
                                          first_txn_in_this_view))

        # 3. Check if the audit ledger has information about view 0 only
        if self._uncommitted_view_no == 0:
            return

        # 4. If the audit ledger has just one txn for the current view (and this view > 0),
        # take the last view from the pool ledger
        if last_txn_in_prev_view is None:
            # assume last view=0 if we don't know it
            self.node_reg_at_beginning_of_view[0] = list(
                self.__load_node_reg_for_first_audit_txn(
                    first_txn_in_this_view))
            return

        # 5. Get the first audit txn for the last view
        first_txn_in_last_view, _ = self.__get_first_txn_in_view_from_audit(
            audit_ledger, last_txn_in_prev_view)

        # 6. Load the last view's node reg (either from the audit ledger, or from the pool
        # ledger if first_txn_in_last_view is the first txn in the audit ledger)
        last_view_no = get_payload_data(
            first_txn_in_last_view)[AUDIT_TXN_VIEW_NO]
        self.node_reg_at_beginning_of_view[last_view_no] = list(
            self.__load_node_reg_for_view(audit_ledger,
                                          first_txn_in_last_view))

    def __load_node_reg_from_pool_ledger(self, to=None):
        node_reg = []
        for _, txn in self.ledger.getAllTxn(to=to):
            if get_type(txn) != NODE:
                continue
            txn_data = get_payload_data(txn)
            node_name = txn_data[DATA][ALIAS]
            services = txn_data[DATA].get(SERVICES)

            if services is None:
                continue

            if node_name not in node_reg and VALIDATOR in services:
                # new node added or old one promoted
                node_reg.append(node_name)
            elif node_name in node_reg and VALIDATOR not in services:
                # existing node demoted
                node_reg.remove(node_name)
        return node_reg

    # TODO: create a helper class to get data from Audit Ledger
    def __load_current_node_reg_from_audit_ledger(self):
        audit_ledger = self.database_manager.get_ledger(AUDIT_LEDGER_ID)
        if not audit_ledger:
            return None

        last_txn = audit_ledger.get_last_committed_txn()
        last_txn_node_reg = get_payload_data(last_txn).get(AUDIT_TXN_NODE_REG)
        if last_txn_node_reg is None:
            return None

        if isinstance(last_txn_node_reg, int):
            seq_no = get_seq_no(last_txn) - last_txn_node_reg
            audit_txn_for_seq_no = audit_ledger.getBySeqNo(seq_no)
            last_txn_node_reg = get_payload_data(audit_txn_for_seq_no).get(
                AUDIT_TXN_NODE_REG)

        if last_txn_node_reg is None:
            return None
        return last_txn_node_reg

    def __load_node_reg_for_view(self, audit_ledger, audit_txn):
        txn_seq_no = get_seq_no(audit_txn)
        audit_txn_data = get_payload_data(audit_txn)

        # If this is the first txn in the audit ledger (so we don't know the full history),
        # get the node reg from the pool ledger
        if txn_seq_no <= 1:
            return self.__load_node_reg_for_first_audit_txn(audit_txn)

        # Get the node reg from audit txn
        node_reg = audit_txn_data.get(AUDIT_TXN_NODE_REG)
        if node_reg is None:
            return self.__load_node_reg_for_first_audit_txn(audit_txn)

        if isinstance(node_reg, int):
            seq_no = get_seq_no(audit_txn) - node_reg
            prev_audit_txn = audit_ledger.getBySeqNo(seq_no)
            node_reg = get_payload_data(prev_audit_txn).get(AUDIT_TXN_NODE_REG)

        if node_reg is None:
            return self.__load_node_reg_for_first_audit_txn(audit_txn)

        return node_reg

    def __get_first_txn_in_view_from_audit(self, audit_ledger,
                                           this_view_first_txn):
        '''
        :param audit_ledger: audit ledger
        :param this_view_first_txn: a txn from the current view
        :return: the first txn in this view and the last txn in the previous view (if any, otherwise None)
        '''
        this_txn_view_no = get_payload_data(this_view_first_txn).get(
            AUDIT_TXN_VIEW_NO)

        prev_view_last_txn = None
        while True:
            txn_primaries = get_payload_data(this_view_first_txn).get(
                AUDIT_TXN_PRIMARIES)
            if isinstance(txn_primaries, int):
                seq_no = get_seq_no(this_view_first_txn) - txn_primaries
                this_view_first_txn = audit_ledger.getBySeqNo(seq_no)
            this_txn_seqno = get_seq_no(this_view_first_txn)
            if this_txn_seqno <= 1:
                break
            prev_view_last_txn = audit_ledger.getBySeqNo(this_txn_seqno - 1)
            prev_txn_view_no = get_payload_data(prev_view_last_txn).get(
                AUDIT_TXN_VIEW_NO)

            if this_txn_view_no != prev_txn_view_no:
                break

            this_view_first_txn = prev_view_last_txn
            prev_view_last_txn = None

        return this_view_first_txn, prev_view_last_txn

    def __load_node_reg_for_first_audit_txn(self, audit_txn):
        # If this is the first txn in the audit ledger (so we don't know the full history),
        # get the node reg from the pool ledger
        audit_txn_data = get_payload_data(audit_txn)
        genesis_pool_ledger_size = audit_txn_data[AUDIT_TXN_LEDGERS_SIZE][
            POOL_LEDGER_ID]
        return self.__load_node_reg_from_pool_ledger(
            to=genesis_pool_ledger_size)
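
The view-pruning rule in commit_batch above can be checked in isolation: keep the node reg for the committed view, the view just before it, and every later view; drop anything older. A standalone sketch using the numbers from the comment:

from sortedcontainers import SortedDict

node_reg_at_beginning_of_view = SortedDict(
    {v: ['reg-at-view-%d' % v] for v in (0, 3, 5, 7, 11, 13)})
committed_view_no = 7

view_nos = list(node_reg_at_beginning_of_view.keys())
prev_committed_index = max(view_nos.index(committed_view_no) - 1, 0) \
    if committed_view_no in node_reg_at_beginning_of_view else 0
for view_no in view_nos[:prev_committed_index]:
    node_reg_at_beginning_of_view.pop(view_no, None)

print(list(node_reg_at_beginning_of_view))  # [5, 7, 11, 13]: views 0 and 3 were deleted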